sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,6 @@
16
16
  import asyncio
17
17
  import copy
18
18
  import dataclasses
19
- import json
20
19
  import logging
21
20
  import math
22
21
  import os
@@ -25,7 +24,6 @@ import signal
25
24
  import sys
26
25
  import threading
27
26
  import time
28
- import uuid
29
27
  from collections import deque
30
28
  from contextlib import nullcontext
31
29
  from datetime import datetime
@@ -34,40 +32,33 @@ from http import HTTPStatus
34
32
  from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
35
33
 
36
34
  import fastapi
35
+ import orjson
37
36
  import torch
38
37
  import uvloop
39
38
  import zmq
40
39
  import zmq.asyncio
41
40
  from fastapi import BackgroundTasks
42
41
 
43
- from sglang.srt.aio_rwlock import RWLock
44
42
  from sglang.srt.configs.model_config import ModelConfig
45
43
  from sglang.srt.disaggregation.utils import DisaggregationMode
46
- from sglang.srt.hf_transformers_utils import (
47
- get_processor,
48
- get_tokenizer,
49
- get_tokenizer_from_processor,
50
- )
51
- from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
44
+ from sglang.srt.lora.lora_registry import LoRARegistry
52
45
  from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
53
46
  from sglang.srt.managers.disagg_service import start_disagg_service
54
47
  from sglang.srt.managers.io_struct import (
55
48
  AbortReq,
56
- BatchEmbeddingOut,
57
- BatchMultimodalOut,
58
- BatchStrOut,
59
- BatchTokenIDOut,
49
+ BaseReq,
50
+ BatchEmbeddingOutput,
51
+ BatchMultimodalOutput,
52
+ BatchStrOutput,
53
+ BatchTokenIDOutput,
60
54
  BatchTokenizedEmbeddingReqInput,
61
55
  BatchTokenizedGenerateReqInput,
62
- CloseSessionReqInput,
63
56
  ConfigureLoggingReq,
64
57
  EmbeddingReqInput,
65
58
  FreezeGCReq,
66
59
  GenerateReqInput,
67
60
  GetLoadReqInput,
68
61
  HealthCheckOutput,
69
- MultiTokenizerWrapper,
70
- OpenSessionReqInput,
71
62
  OpenSessionReqOutput,
72
63
  SessionParams,
73
64
  TokenizedEmbeddingReqInput,
@@ -84,6 +75,7 @@ from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicat
84
75
  from sglang.srt.metrics.collector import TokenizerMetricsCollector
85
76
  from sglang.srt.sampling.sampling_params import SamplingParams
86
77
  from sglang.srt.server_args import PortArgs, ServerArgs
78
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
87
79
  from sglang.srt.tracing.trace import (
88
80
  trace_get_proc_propagate_context,
89
81
  trace_req_finish,
@@ -96,10 +88,15 @@ from sglang.srt.utils import (
96
88
  dataclass_to_string_truncated,
97
89
  freeze_gc,
98
90
  get_bool_env_var,
99
- get_origin_rid,
100
91
  get_zmq_socket,
101
92
  kill_process_tree,
102
93
  )
94
+ from sglang.srt.utils.aio_rwlock import RWLock
95
+ from sglang.srt.utils.hf_transformers_utils import (
96
+ get_processor,
97
+ get_tokenizer,
98
+ get_tokenizer_from_processor,
99
+ )
103
100
  from sglang.utils import TypeBasedDispatcher, get_exception_traceback
104
101
 
105
102
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -158,11 +155,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
158
155
  self.log_requests = server_args.log_requests
159
156
  self.log_requests_level = server_args.log_requests_level
160
157
  self.preferred_sampling_params = (
161
- json.loads(server_args.preferred_sampling_params)
158
+ orjson.loads(server_args.preferred_sampling_params)
162
159
  if server_args.preferred_sampling_params
163
160
  else None
164
161
  )
165
162
  self.crash_dump_folder = server_args.crash_dump_folder
163
+ self.enable_trace = server_args.enable_trace
166
164
 
167
165
  # Read model args
168
166
  self.model_path = server_args.model_path
@@ -174,8 +172,19 @@ class TokenizerManager(TokenizerCommunicatorMixin):
174
172
  self.image_token_id = self.model_config.image_token_id
175
173
  self.max_req_input_len = None # Will be set later in engine.py
176
174
 
175
+ speculative_algorithm = SpeculativeAlgorithm.from_string(
176
+ server_args.speculative_algorithm
177
+ )
178
+ self.reserve_input_token_num = (
179
+ 0
180
+ if speculative_algorithm.is_none()
181
+ else server_args.speculative_num_draft_tokens
182
+ )
183
+ # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
184
+ self.multi_item_delimiter_text = None
185
+
177
186
  if self.model_config.is_multimodal:
178
- import_processors()
187
+ import_processors("sglang.srt.multimodal.processors")
179
188
  try:
180
189
  _processor = get_processor(
181
190
  server_args.tokenizer_path,
@@ -214,6 +223,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
214
223
  self.processor = _processor
215
224
  self.tokenizer = get_tokenizer_from_processor(self.processor)
216
225
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
226
+ self._initialize_multi_item_delimiter_text()
217
227
  else:
218
228
  self.mm_processor = self.processor = None
219
229
 
@@ -226,6 +236,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
226
236
  trust_remote_code=server_args.trust_remote_code,
227
237
  revision=server_args.revision,
228
238
  )
239
+ self._initialize_multi_item_delimiter_text()
229
240
  # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
230
241
  if (
231
242
  server_args.enable_dynamic_batch_tokenizer
@@ -246,16 +257,25 @@ class TokenizerManager(TokenizerCommunicatorMixin):
246
257
  )
247
258
  if self.server_args.tokenizer_worker_num > 1:
248
259
  # Use tokenizer_worker_ipc_name in multi-tokenizer mode
249
- self.send_to_scheduler = get_zmq_socket(
260
+ send_to_scheduler = get_zmq_socket(
250
261
  context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
251
262
  )
263
+
264
+ class SenderWrapper:
265
+ def send_pyobj(self, obj):
266
+ if isinstance(obj, BaseReq):
267
+ obj.http_worker_ipc = port_args.tokenizer_ipc_name
268
+ send_to_scheduler.send_pyobj(obj)
269
+
270
+ # Make sure that each request carries the tokenizer_ipc_name for response routing
271
+ self.send_to_scheduler = SenderWrapper()
252
272
  else:
253
273
  self.send_to_scheduler = get_zmq_socket(
254
274
  context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
255
275
  )
256
276
 
257
277
  # Request states
258
- self.no_create_loop = False
278
+ self._chosen_loop = None
259
279
  self.rid_to_state: Dict[str, ReqState] = {}
260
280
  self.asyncio_tasks = set()
261
281
 
@@ -264,6 +284,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
264
284
  self.gracefully_exit = False
265
285
  self.last_receive_tstamp = 0
266
286
 
287
+ # Initial weights status
288
+ self.initial_weights_loaded = True
289
+ if server_args.checkpoint_engine_wait_weights_before_ready:
290
+ self.initial_weights_loaded = False
291
+
267
292
  # Dumping
268
293
  self.dump_requests_folder = "" # By default do not dump
269
294
  self.dump_requests_threshold = 1000
@@ -310,8 +335,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
310
335
  "model_name": self.server_args.served_model_name,
311
336
  # TODO: Add lora name/path in the future,
312
337
  }
313
- if server_args.tokenizer_metrics_allowed_customer_labels:
314
- for label in server_args.tokenizer_metrics_allowed_customer_labels:
338
+ if server_args.tokenizer_metrics_allowed_custom_labels:
339
+ for label in server_args.tokenizer_metrics_allowed_custom_labels:
315
340
  labels[label] = ""
316
341
  self.metrics_collector = TokenizerMetricsCollector(
317
342
  server_args=server_args,
@@ -330,10 +355,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
330
355
  [
331
356
  (
332
357
  (
333
- BatchStrOut,
334
- BatchEmbeddingOut,
335
- BatchTokenIDOut,
336
- BatchMultimodalOut,
358
+ BatchStrOutput,
359
+ BatchEmbeddingOutput,
360
+ BatchTokenIDOutput,
361
+ BatchMultimodalOutput,
337
362
  ),
338
363
  self._handle_batch_output,
339
364
  ),
@@ -346,7 +371,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
346
371
  (
347
372
  FreezeGCReq,
348
373
  lambda x: None,
349
- ), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
374
+ ),
375
+ # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
350
376
  (HealthCheckOutput, lambda x: None),
351
377
  ]
352
378
  )
@@ -363,31 +389,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
363
389
  obj.normalize_batch_and_arguments()
364
390
 
365
391
  if self.server_args.tokenizer_worker_num > 1:
366
- # Modify rid, add worker_id
367
- if isinstance(obj.rid, list):
368
- # If it's an array, add worker_id prefix to each element
369
- obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
370
- else:
371
- # If it's a single value, add worker_id prefix
372
- obj.rid = f"{self.worker_id}_{obj.rid}"
392
+ from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker
373
393
 
374
- if obj.is_single:
375
- bootstrap_room = (
376
- obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
377
- )
378
- trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
379
- trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
380
- else:
381
- for i in range(len(obj.rid)):
382
- bootstrap_room = (
383
- obj.bootstrap_room[i]
384
- if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
385
- else None
386
- )
387
- trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
388
- trace_slice_start(
389
- "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
390
- )
394
+ assert isinstance(self, TokenizerWorker)
395
+ self._attach_multi_http_worker_info(obj)
396
+
397
+ if self.enable_trace:
398
+ self._trace_request_start(obj, created_time)
391
399
 
392
400
  if self.log_requests:
393
401
  max_length, skip_names, _ = self.log_request_metadata
@@ -588,9 +596,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
588
596
  )
589
597
 
590
598
  if self.mm_processor and obj.contains_mm_input():
591
- if not isinstance(obj.image_data, list):
599
+ if obj.image_data is not None and not isinstance(obj.image_data, list):
592
600
  obj.image_data = [obj.image_data]
593
- if not isinstance(obj.audio_data, list):
601
+ if obj.audio_data is not None and not isinstance(obj.audio_data, list):
594
602
  obj.audio_data = [obj.audio_data]
595
603
  mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
596
604
  image_data=obj.image_data,
@@ -618,6 +626,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
618
626
  _max_req_len = self.context_len
619
627
 
620
628
  input_token_num = len(input_ids) if input_ids is not None else 0
629
+ input_token_num += self.reserve_input_token_num
621
630
  if input_token_num >= self.context_len:
622
631
  if self.server_args.allow_auto_truncate:
623
632
  logger.warning(
@@ -719,7 +728,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
719
728
  )
720
729
 
721
730
  tokenized_obj = TokenizedGenerateReqInput(
722
- obj.rid,
723
731
  input_text,
724
732
  input_ids,
725
733
  mm_inputs,
@@ -729,6 +737,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
729
737
  obj.top_logprobs_num,
730
738
  obj.token_ids_logprob,
731
739
  obj.stream,
740
+ rid=obj.rid,
741
+ http_worker_ipc=obj.http_worker_ipc,
732
742
  bootstrap_host=obj.bootstrap_host,
733
743
  bootstrap_port=obj.bootstrap_port,
734
744
  bootstrap_room=obj.bootstrap_room,
@@ -738,15 +748,19 @@ class TokenizerManager(TokenizerCommunicatorMixin):
738
748
  custom_logit_processor=obj.custom_logit_processor,
739
749
  return_hidden_states=obj.return_hidden_states,
740
750
  data_parallel_rank=obj.data_parallel_rank,
751
+ priority=obj.priority,
752
+ extra_key=obj.extra_key,
741
753
  )
742
754
  elif isinstance(obj, EmbeddingReqInput):
743
755
  tokenized_obj = TokenizedEmbeddingReqInput(
744
- obj.rid,
745
756
  input_text,
746
757
  input_ids,
747
758
  mm_inputs,
748
759
  token_type_ids,
749
760
  sampling_params,
761
+ rid=obj.rid,
762
+ priority=obj.priority,
763
+ http_worker_ipc=obj.http_worker_ipc,
750
764
  )
751
765
 
752
766
  return tokenized_obj
@@ -757,6 +771,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
757
771
  """Handle batch tokenization for text inputs only."""
758
772
  logger.debug(f"Starting batch tokenization for {batch_size} text requests")
759
773
 
774
+ # If batch does not have text nothing to tokenize
775
+ # so lets construct the return object
776
+ if not self._batch_has_text(batch_size, obj):
777
+ # All requests already have input_ids, no need to tokenize
778
+ return [await self._tokenize_one_request(obj[i]) for i in range(batch_size)]
779
+
780
+ self._validate_batch_tokenization_constraints(batch_size, obj)
781
+
760
782
  # Collect requests and texts
761
783
  requests = [obj[i] for i in range(batch_size)]
762
784
  texts = [req.text for req in requests]
@@ -806,6 +828,30 @@ class TokenizerManager(TokenizerCommunicatorMixin):
806
828
  "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
807
829
  )
808
830
 
831
+ def _batch_has_text(
832
+ self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
833
+ ) -> bool:
834
+ """Check if any request in the batch contains text input."""
835
+ for i in range(batch_size):
836
+ if obj[i].text:
837
+ return True
838
+ elif self.is_generation and obj[i].contains_mm_input():
839
+ return True
840
+
841
+ return False
842
+
843
+ def _should_use_batch_tokenization(self, batch_size, requests) -> bool:
844
+ """Return True if we should run the tokenizer in batch mode.
845
+
846
+ Current policy:
847
+ - Respect explicit server flag `enable_tokenizer_batch_encode`.
848
+ - Or, if no request has text or multimodal input (all use pre-tokenized input_ids or input_embeds), batch the requests without tokenization.
849
+ """
850
+ return batch_size > 0 and (
851
+ self.server_args.enable_tokenizer_batch_encode
852
+ or not self._batch_has_text(batch_size, requests)
853
+ )
854
+
809
855
  def _send_one_request(
810
856
  self,
811
857
  obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -940,13 +986,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
940
986
  generators = []
941
987
  rids = []
942
988
  if getattr(obj, "parallel_sample_num", 1) == 1:
943
- if self.server_args.enable_tokenizer_batch_encode:
944
- # Validate batch tokenization constraints
945
- self._validate_batch_tokenization_constraints(batch_size, obj)
946
-
989
+ if self._should_use_batch_tokenization(batch_size, obj):
947
990
  tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
948
-
949
- # Send as a single batched request
950
991
  self._send_batch_request(obj, tokenized_objs, created_time)
951
992
 
952
993
  # Set up generators for each request in the batch
@@ -1038,10 +1079,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1038
1079
  def abort_request(self, rid: str = "", abort_all: bool = False):
1039
1080
  if not abort_all and rid not in self.rid_to_state:
1040
1081
  return
1041
- req = AbortReq(rid, abort_all)
1082
+ req = AbortReq(rid=rid, abort_all=abort_all)
1042
1083
  self.send_to_scheduler.send_pyobj(req)
1043
1084
  if self.enable_metrics:
1044
- self.metrics_collector.observe_one_aborted_request()
1085
+ # TODO: also use custom_labels from the request
1086
+ self.metrics_collector.observe_one_aborted_request(
1087
+ self.metrics_collector.labels
1088
+ )
1045
1089
 
1046
1090
  async def pause_generation(self):
1047
1091
  async with self.is_pause_cond:
@@ -1077,8 +1121,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1077
1121
  async def _wait_for_model_update_from_disk(
1078
1122
  self, obj: UpdateWeightFromDiskReqInput
1079
1123
  ) -> Tuple[bool, str]:
1080
- if self.server_args.tokenizer_worker_num > 1:
1081
- obj = MultiTokenizerWrapper(self.worker_id, obj)
1082
1124
  self.send_to_scheduler.send_pyobj(obj)
1083
1125
  self.model_update_result = asyncio.Future()
1084
1126
  if self.server_args.dp_size == 1:
@@ -1103,84 +1145,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1103
1145
  all_paused_requests = [r.num_paused_requests for r in result]
1104
1146
  return all_success, all_message, all_paused_requests
1105
1147
 
1106
- async def open_session(
1107
- self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
1108
- ):
1109
- self.auto_create_handle_loop()
1110
-
1111
- if obj.session_id is None:
1112
- obj.session_id = uuid.uuid4().hex
1113
- elif obj.session_id in self.session_futures:
1114
- return None
1115
-
1116
- if self.server_args.tokenizer_worker_num > 1:
1117
- obj = MultiTokenizerWrapper(self.worker_id, obj)
1118
- self.send_to_scheduler.send_pyobj(obj)
1119
-
1120
- self.session_futures[obj.session_id] = asyncio.Future()
1121
- session_id = await self.session_futures[obj.session_id]
1122
- del self.session_futures[obj.session_id]
1123
- return session_id
1124
-
1125
- async def close_session(
1126
- self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
1127
- ):
1128
- await self.send_to_scheduler.send_pyobj(obj)
1129
-
1130
- def get_log_request_metadata(self):
1131
- max_length = None
1132
- skip_names = None
1133
- out_skip_names = None
1134
- if self.log_requests:
1135
- if self.log_requests_level == 0:
1136
- max_length = 1 << 30
1137
- skip_names = set(
1138
- [
1139
- "text",
1140
- "input_ids",
1141
- "input_embeds",
1142
- "image_data",
1143
- "audio_data",
1144
- "lora_path",
1145
- "sampling_params",
1146
- ]
1147
- )
1148
- out_skip_names = set(
1149
- [
1150
- "text",
1151
- "output_ids",
1152
- "embedding",
1153
- ]
1154
- )
1155
- elif self.log_requests_level == 1:
1156
- max_length = 1 << 30
1157
- skip_names = set(
1158
- [
1159
- "text",
1160
- "input_ids",
1161
- "input_embeds",
1162
- "image_data",
1163
- "audio_data",
1164
- "lora_path",
1165
- ]
1166
- )
1167
- out_skip_names = set(
1168
- [
1169
- "text",
1170
- "output_ids",
1171
- "embedding",
1172
- ]
1173
- )
1174
- elif self.log_requests_level == 2:
1175
- max_length = 2048
1176
- elif self.log_requests_level == 3:
1177
- max_length = 1 << 30
1178
- else:
1179
- raise ValueError(
1180
- f"Invalid --log-requests-level: {self.log_requests_level=}"
1181
- )
1182
- return max_length, skip_names, out_skip_names
1183
-
1184
1148
  def configure_logging(self, obj: ConfigureLoggingReq):
1185
1149
  if obj.log_requests is not None:
1186
1150
  self.log_requests = obj.log_requests
@@ -1216,11 +1180,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1216
1180
  return background_tasks
1217
1181
 
1218
1182
  def auto_create_handle_loop(self):
1219
- if self.no_create_loop:
1183
+ if self._chosen_loop is not None:
1184
+ assert (
1185
+ asyncio.get_event_loop() == self._chosen_loop
1186
+ ), f"Please ensure only one event loop is ever used with SGLang. Previous loop: {self._chosen_loop}, current loop: {asyncio.get_event_loop()}"
1220
1187
  return
1221
1188
 
1222
- self.no_create_loop = True
1223
1189
  loop = asyncio.get_event_loop()
1190
+ self._chosen_loop = loop
1224
1191
  self.asyncio_tasks.add(
1225
1192
  loop.create_task(print_exception_wrapper(self.handle_loop))
1226
1193
  )
@@ -1339,12 +1306,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1339
1306
  # Drain requests
1340
1307
  while True:
1341
1308
  remain_num_req = len(self.rid_to_state)
1309
+ remaining_rids = list(self.rid_to_state.keys())
1342
1310
 
1343
1311
  if self.server_status == ServerStatus.UnHealthy:
1344
1312
  # if health check failed, we should exit immediately
1345
1313
  logger.error(
1346
- "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
1347
- remain_num_req,
1314
+ "Signal SIGTERM received while health check failed. Force exiting."
1348
1315
  )
1349
1316
  self.dump_requests_before_crash()
1350
1317
  break
@@ -1352,13 +1319,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1352
1319
  elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
1353
1320
  # if force shutdown flag set, exit immediately
1354
1321
  logger.error(
1355
- "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
1356
- remain_num_req,
1322
+ "Signal SIGTERM received while force shutdown flag set. Force exiting."
1357
1323
  )
1358
1324
  break
1359
1325
 
1360
1326
  logger.info(
1361
- f"Gracefully exiting... remaining number of requests {remain_num_req}"
1327
+ f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
1362
1328
  )
1363
1329
  if remain_num_req > 0:
1364
1330
  await asyncio.sleep(5)
@@ -1379,7 +1345,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1379
1345
  def _handle_batch_output(
1380
1346
  self,
1381
1347
  recv_obj: Union[
1382
- BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
1348
+ BatchStrOutput,
1349
+ BatchEmbeddingOutput,
1350
+ BatchMultimodalOutput,
1351
+ BatchTokenIDOutput,
1383
1352
  ],
1384
1353
  ):
1385
1354
  for i, rid in enumerate(recv_obj.rids):
@@ -1390,12 +1359,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1390
1359
  )
1391
1360
  continue
1392
1361
 
1393
- origin_rid = rid
1394
- if self.server_args.tokenizer_worker_num > 1:
1395
- origin_rid = get_origin_rid(rid)
1396
1362
  # Build meta_info and return value
1397
1363
  meta_info = {
1398
- "id": origin_rid,
1364
+ "id": rid,
1399
1365
  "finish_reason": recv_obj.finished_reasons[i],
1400
1366
  "prompt_tokens": recv_obj.prompt_tokens[i],
1401
1367
  "weight_version": self.server_args.weight_version,
@@ -1413,7 +1379,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1413
1379
  i,
1414
1380
  )
1415
1381
 
1416
- if not isinstance(recv_obj, BatchEmbeddingOut):
1382
+ if not isinstance(recv_obj, BatchEmbeddingOutput):
1417
1383
  meta_info.update(
1418
1384
  {
1419
1385
  "completion_tokens": recv_obj.completion_tokens[i],
@@ -1424,7 +1390,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1424
1390
  if getattr(recv_obj, "output_hidden_states", None):
1425
1391
  meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
1426
1392
 
1427
- if isinstance(recv_obj, BatchStrOut):
1393
+ if isinstance(recv_obj, BatchStrOutput):
1428
1394
  state.text += recv_obj.output_strs[i]
1429
1395
  if state.obj.stream:
1430
1396
  state.output_ids.extend(recv_obj.output_ids[i])
@@ -1439,7 +1405,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1439
1405
  "output_ids": output_token_ids,
1440
1406
  "meta_info": meta_info,
1441
1407
  }
1442
- elif isinstance(recv_obj, BatchTokenIDOut):
1408
+ elif isinstance(recv_obj, BatchTokenIDOutput):
1443
1409
  if self.server_args.stream_output and state.obj.stream:
1444
1410
  state.output_ids.extend(recv_obj.output_ids[i])
1445
1411
  output_token_ids = state.output_ids[state.last_output_offset :]
@@ -1452,10 +1418,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1452
1418
  "output_ids": output_token_ids,
1453
1419
  "meta_info": meta_info,
1454
1420
  }
1455
- elif isinstance(recv_obj, BatchMultimodalOut):
1421
+ elif isinstance(recv_obj, BatchMultimodalOutput):
1456
1422
  raise NotImplementedError("BatchMultimodalOut not implemented")
1457
1423
  else:
1458
- assert isinstance(recv_obj, BatchEmbeddingOut)
1424
+ assert isinstance(recv_obj, BatchEmbeddingOutput)
1459
1425
  out_dict = {
1460
1426
  "embedding": recv_obj.embeddings[i],
1461
1427
  "meta_info": meta_info,
@@ -1464,7 +1430,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1464
1430
  state.finished = recv_obj.finished_reasons[i] is not None
1465
1431
  if state.finished:
1466
1432
  if self.server_args.speculative_algorithm:
1467
- meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
1433
+ self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)
1468
1434
  state.finished_time = time.time()
1469
1435
  meta_info["e2e_latency"] = state.finished_time - state.created_time
1470
1436
 
@@ -1494,7 +1460,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1494
1460
  top_logprobs_num: int,
1495
1461
  token_ids_logprob: List[int],
1496
1462
  return_text_in_logprobs: bool,
1497
- recv_obj: BatchStrOut,
1463
+ recv_obj: BatchStrOutput,
1498
1464
  recv_obj_index: int,
1499
1465
  ):
1500
1466
  if recv_obj.input_token_logprobs_val is None:
@@ -1612,17 +1578,54 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1612
1578
  ret.append(None)
1613
1579
  return ret
1614
1580
 
1615
- def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
1581
+ def _calculate_spec_decoding_metrics(
1582
+ self,
1583
+ meta_info: Dict[str, Any],
1584
+ recv_obj: Union[
1585
+ BatchStrOutput,
1586
+ BatchEmbeddingOutput,
1587
+ BatchMultimodalOutput,
1588
+ BatchTokenIDOutput,
1589
+ ],
1590
+ i: int,
1591
+ ) -> None:
1592
+ """Calculate speculative decoding metrics, such as acceptance rate and acceptance length metrics."""
1593
+ meta_info["spec_accept_rate"] = 0.0
1594
+ meta_info["spec_accept_length"] = 0
1595
+ meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
1596
+
1597
+ if (
1598
+ recv_obj.spec_verify_ct[i] > 0
1599
+ and self.server_args.speculative_num_steps is not None
1600
+ and not isinstance(recv_obj, BatchEmbeddingOutput)
1601
+ and hasattr(recv_obj, "spec_accepted_tokens")
1602
+ # Checks that `spec_accepted_tokens[i]` will exist.
1603
+ and len(recv_obj.spec_accepted_tokens) > i
1604
+ ):
1605
+ total_draft_tokens = (
1606
+ recv_obj.spec_verify_ct[i] * self.server_args.speculative_num_steps
1607
+ )
1608
+ accepted_tokens = recv_obj.spec_accepted_tokens[i]
1609
+
1610
+ # Calculate per-request acceptance rate and average acceptance length.
1611
+ if total_draft_tokens > 0:
1612
+ # Calculate acceptance rate: accepted / (steps * lookahead)
1613
+ meta_info["spec_accept_rate"] = accepted_tokens / total_draft_tokens
1614
+ meta_info["spec_accept_length"] = (
1615
+ recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i]
1616
+ )
1617
+
1618
+ def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
1616
1619
  completion_tokens = (
1617
1620
  recv_obj.completion_tokens[i]
1618
1621
  if getattr(recv_obj, "completion_tokens", None)
1619
1622
  else 0
1620
1623
  )
1621
1624
 
1622
- customer_labels = getattr(state.obj, "customer_labels", None)
1625
+ custom_labels = getattr(state.obj, "custom_labels", None)
1623
1626
  labels = (
1624
- {**self.metrics_collector.labels, **customer_labels}
1625
- if customer_labels
1627
+ {**self.metrics_collector.labels, **custom_labels}
1628
+ if custom_labels
1626
1629
  else self.metrics_collector.labels
1627
1630
  )
1628
1631
  if (
@@ -1708,13 +1711,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1708
1711
 
1709
1712
  asyncio.create_task(asyncio.to_thread(background_task))
1710
1713
 
1711
- def _handle_abort_req(self, recv_obj):
1714
+ def _handle_abort_req(self, recv_obj: AbortReq):
1712
1715
  if is_health_check_generate_req(recv_obj):
1713
1716
  return
1714
1717
  state = self.rid_to_state[recv_obj.rid]
1715
- origin_rid = recv_obj.rid
1716
- if self.server_args.tokenizer_worker_num > 1:
1717
- origin_rid = get_origin_rid(origin_rid)
1718
1718
  state.finished = True
1719
1719
  if recv_obj.finished_reason:
1720
1720
  out = {
@@ -1727,7 +1727,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1727
1727
  out = {
1728
1728
  "text": "",
1729
1729
  "meta_info": {
1730
- "id": origin_rid,
1730
+ "id": recv_obj.rid,
1731
1731
  "finish_reason": {
1732
1732
  "type": "abort",
1733
1733
  "message": "Abort before prefill",
@@ -1753,6 +1753,201 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1753
1753
  if len(self.model_update_tmp) == self.server_args.dp_size:
1754
1754
  self.model_update_result.set_result(self.model_update_tmp)
1755
1755
 
1756
+ def _initialize_multi_item_delimiter_text(self):
1757
+ """Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
1758
+ if (
1759
+ hasattr(self.server_args, "multi_item_scoring_delimiter")
1760
+ and self.server_args.multi_item_scoring_delimiter is not None
1761
+ and self.tokenizer is not None
1762
+ ):
1763
+ try:
1764
+ self.multi_item_delimiter_text = self.tokenizer.decode(
1765
+ [self.server_args.multi_item_scoring_delimiter],
1766
+ skip_special_tokens=False,
1767
+ )
1768
+ except Exception as e:
1769
+ logger.warning(
1770
+ f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
1771
+ )
1772
+ self.multi_item_delimiter_text = None
1773
+
1774
+ def _build_multi_item_token_sequence(
1775
+ self, query: List[int], items: List[List[int]], delimiter_token_id: int
1776
+ ) -> List[int]:
1777
+ """
1778
+ Build a single token sequence for multi-item scoring.
1779
+ Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
1780
+
1781
+ Args:
1782
+ query: Query token IDs
1783
+ items: List of item token ID sequences
1784
+ delimiter_token_id: Token ID to use as delimiter
1785
+
1786
+ Returns:
1787
+ Combined token sequence
1788
+ """
1789
+ combined_sequence = query[:] # Start with query
1790
+
1791
+ for item in items:
1792
+ combined_sequence.append(delimiter_token_id) # Add delimiter
1793
+ combined_sequence.extend(item) # Add item tokens
1794
+
1795
+ # Add final delimiter after the last item for logprob extraction
1796
+ combined_sequence.append(delimiter_token_id)
1797
+
1798
+ return combined_sequence
1799
+
1800
+ def _extract_logprobs_for_tokens(
1801
+ self, logprobs_data: List, label_token_ids: List[int]
1802
+ ) -> Dict[int, float]:
1803
+ """
1804
+ Extract logprobs for specified token IDs from logprobs data.
1805
+
1806
+ Args:
1807
+ logprobs_data: List of (logprob, token_id, text) tuples
1808
+ label_token_ids: Token IDs to extract logprobs for
1809
+
1810
+ Returns:
1811
+ Dictionary mapping token_id to logprob
1812
+ """
1813
+ logprobs = {}
1814
+ if logprobs_data:
1815
+ for logprob, token_id, _ in logprobs_data:
1816
+ if token_id in label_token_ids:
1817
+ logprobs[token_id] = logprob
1818
+ return logprobs
1819
+
1820
+ def _convert_logprobs_to_scores(
1821
+ self,
1822
+ logprobs: Dict[int, float],
1823
+ label_token_ids: List[int],
1824
+ apply_softmax: bool,
1825
+ ) -> List[float]:
1826
+ """
1827
+ Convert logprobs dictionary to ordered score list.
1828
+
1829
+ Args:
1830
+ logprobs: Dictionary mapping token_id to logprob
1831
+ label_token_ids: Token IDs in desired order
1832
+ apply_softmax: Whether to apply softmax normalization
1833
+
1834
+ Returns:
1835
+ List of scores in the same order as label_token_ids
1836
+ """
1837
+ score_list = [
1838
+ logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
1839
+ ]
1840
+
1841
+ if apply_softmax:
1842
+ score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
1843
+ else:
1844
+ # Convert logprobs to probabilities if not using softmax
1845
+ score_list = [
1846
+ math.exp(x) if x != float("-inf") else 0.0 for x in score_list
1847
+ ]
1848
+
1849
+ return score_list
1850
+
1851
+ def _process_multi_item_scoring_results(
1852
+ self,
1853
+ results: Any,
1854
+ items: List,
1855
+ label_token_ids: List[int],
1856
+ apply_softmax: bool,
1857
+ batch_request=None,
1858
+ ) -> List[List[float]]:
1859
+ """
1860
+ Process results from multi-item scoring request.
1861
+ Extracts logprobs at delimiter positions from input_token_ids_logprobs.
1862
+
1863
+ Args:
1864
+ results: Results from generate_request
1865
+ items: List of items being scored
1866
+ label_token_ids: Token IDs to extract scores for
1867
+ apply_softmax: Whether to apply softmax normalization
1868
+ batch_request: The original batch request containing input sequence
1869
+
1870
+ Returns:
1871
+ List of score lists, one for each item
1872
+ """
1873
+ single_result = results[0] if isinstance(results, list) else results
1874
+
1875
+ # For multi-item scoring, logprobs are in input_token_ids_logprobs
1876
+ input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
1877
+
1878
+ if not input_logprobs:
1879
+ raise RuntimeError(
1880
+ f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
1881
+ "This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
1882
+ )
1883
+
1884
+ scores = []
1885
+ num_items = len(items) if isinstance(items, list) else 1
1886
+
1887
+ # Check if we have the expected number of logprobs
1888
+ expected_logprobs_count = num_items + 1
1889
+ if len(input_logprobs) != expected_logprobs_count:
1890
+ raise RuntimeError(
1891
+ f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
1892
+ f"with {num_items} items, but got {len(input_logprobs)}. "
1893
+ f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
1894
+ )
1895
+
1896
+ # Skip the first delimiter (between query and first item) and process remaining delimiter positions
1897
+ # We want to exclude the first one since it represents the boundary between query and first item, not an item boundary
1898
+ start_idx = 1 if len(input_logprobs) > 1 else 0
1899
+
1900
+ # Process logprobs for each item position (excluding first delimiter)
1901
+ for item_idx in range(num_items):
1902
+ logprob_idx = start_idx + item_idx
1903
+ item_logprobs_data = input_logprobs[logprob_idx]
1904
+ logprobs = self._extract_logprobs_for_tokens(
1905
+ item_logprobs_data, label_token_ids
1906
+ )
1907
+ score_list = self._convert_logprobs_to_scores(
1908
+ logprobs, label_token_ids, apply_softmax
1909
+ )
1910
+ scores.append(score_list)
1911
+
1912
+ return scores
1913
+
1914
+ def _process_single_item_scoring_results(
1915
+ self, results: Any, label_token_ids: List[int], apply_softmax: bool
1916
+ ) -> List[List[float]]:
1917
+ """
1918
+ Process results from single-item scoring request.
1919
+ Single-item scoring results are stored in output_token_ids_logprobs.
1920
+
1921
+ Args:
1922
+ results: Results from generate_request
1923
+ label_token_ids: Token IDs to extract scores for
1924
+ apply_softmax: Whether to apply softmax normalization
1925
+
1926
+ Returns:
1927
+ List of score lists, one for each result
1928
+ """
1929
+ scores = []
1930
+
1931
+ for result in results:
1932
+ # For single-item scoring, logprobs are in output_token_ids_logprobs
1933
+ output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
1934
+
1935
+ if not output_logprobs or len(output_logprobs) == 0:
1936
+ raise RuntimeError(
1937
+ f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
1938
+ )
1939
+
1940
+ # Extract logprobs for the first (and only) position
1941
+ logprobs = self._extract_logprobs_for_tokens(
1942
+ output_logprobs[0], label_token_ids
1943
+ )
1944
+ score_list = self._convert_logprobs_to_scores(
1945
+ logprobs, label_token_ids, apply_softmax
1946
+ )
1947
+ scores.append(score_list)
1948
+
1949
+ return scores
1950
+
1756
1951
  async def score_request(
1757
1952
  self,
1758
1953
  query: Optional[Union[str, List[int]]] = None,
@@ -1763,7 +1958,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1763
1958
  request: Optional[Any] = None,
1764
1959
  ) -> List[List[float]]:
1765
1960
  """
1766
- See Engine.score() for more details.
1961
+ Score the probability of specified token IDs appearing after the given (query + item) pair.
1962
+
1963
+ This method supports two scoring approaches:
1964
+ 1. Single-Item scoring (default): Process each query+item pair independently
1965
+ 2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
1966
+ multiple items into a single sequence using delimiter for efficient processing.
1967
+ Note: item_first parameter is ignored in multi-item scoring mode since it uses
1968
+ a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
1969
+
1970
+ Multi-item scoring works with both text and pre-tokenized inputs:
1971
+ - Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
1972
+ - Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
1973
+
1974
+ Args:
1975
+ query: The query text or pre-tokenized query token IDs
1976
+ items: The item text(s) or pre-tokenized item token IDs
1977
+ label_token_ids: List of token IDs to compute probabilities for
1978
+ apply_softmax: Whether to normalize probabilities using softmax
1979
+ item_first: If True, prepend items to query. Ignored for multi-item scoring.
1980
+ request: Optional FastAPI request object
1981
+
1982
+ Returns:
1983
+ List of lists containing probabilities for each item and each label token
1767
1984
  """
1768
1985
  if label_token_ids is None:
1769
1986
  raise ValueError("label_token_ids must be provided")
@@ -1776,9 +1993,17 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1776
1993
  f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
1777
1994
  )
1778
1995
 
1996
+ # Check if multi-item scoring is enabled by presence of delimiter
1997
+ use_multi_item_scoring = (
1998
+ self.server_args.multi_item_scoring_delimiter is not None
1999
+ and self.multi_item_delimiter_text is not None
2000
+ )
2001
+
1779
2002
  batch_request = GenerateReqInput(
1780
2003
  token_ids_logprob=label_token_ids,
1781
2004
  return_logprob=True,
2005
+ # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
2006
+ logprob_start_len=0 if use_multi_item_scoring else -1,
1782
2007
  stream=False,
1783
2008
  sampling_params={"max_new_tokens": 0},
1784
2009
  )
@@ -1790,12 +2015,23 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1790
2015
  ):
1791
2016
  # Both query and items are text
1792
2017
  items_list = [items] if isinstance(items, str) else items
1793
- if item_first:
1794
- prompts = [f"{item}{query}" for item in items_list]
1795
- else:
1796
- prompts = [f"{query}{item}" for item in items_list]
1797
2018
 
1798
- batch_request.text = prompts
2019
+ if use_multi_item_scoring:
2020
+ # Multi-item scoring: create single prompt with delimiter text
2021
+ # Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
2022
+ # (item_first is ignored for multi-item scoring)
2023
+ delimiter = self.multi_item_delimiter_text
2024
+ combined_items = delimiter.join(items_list)
2025
+ # Add final delimiter after the last item for logprob extraction
2026
+ single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
2027
+ batch_request.text = [single_prompt]
2028
+ else:
2029
+ # Single-item scoring: create separate prompts for each item
2030
+ if item_first:
2031
+ prompts = [f"{item}{query}" for item in items_list]
2032
+ else:
2033
+ prompts = [f"{query}{item}" for item in items_list]
2034
+ batch_request.text = prompts
1799
2035
 
1800
2036
  elif (
1801
2037
  isinstance(query, list)
@@ -1804,61 +2040,38 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1804
2040
  and isinstance(items[0], list)
1805
2041
  ):
1806
2042
  # Both query and items are token IDs
1807
- if item_first:
1808
- input_ids_list = [item + query for item in items]
2043
+ if use_multi_item_scoring:
2044
+ # Multi-item scoring: concatenate with delimiter token ID
2045
+ # Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
2046
+ delimiter_token_id = self.server_args.multi_item_scoring_delimiter
2047
+ combined_input_ids = self._build_multi_item_token_sequence(
2048
+ query, items, delimiter_token_id
2049
+ )
2050
+ batch_request.input_ids = [combined_input_ids]
1809
2051
  else:
1810
- input_ids_list = [query + item for item in items]
1811
-
1812
- batch_request.input_ids = input_ids_list
2052
+ # Single-item scoring: process each item separately
2053
+ if item_first:
2054
+ input_ids_list = [item + query for item in items]
2055
+ else:
2056
+ input_ids_list = [query + item for item in items]
2057
+ batch_request.input_ids = input_ids_list
1813
2058
  else:
1814
2059
  raise ValueError(
1815
2060
  "Invalid combination of query/items types for score_request."
1816
2061
  )
1817
2062
 
1818
2063
  results = await self.generate_request(batch_request, request).__anext__()
1819
- scores = []
1820
-
1821
- for result in results:
1822
- # Get logprobs for each token
1823
- logprobs = {}
1824
-
1825
- # For scoring requests, we read from output_token_ids_logprobs since we want
1826
- # the logprobs for specific tokens mentioned in the label_token_ids at
1827
- # the next position after the last token in the prompt
1828
- output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
1829
2064
 
1830
- # Check if output_logprobs is properly populated
1831
- if (
1832
- output_logprobs is None
1833
- or not output_logprobs
1834
- or len(output_logprobs) == 0
1835
- ):
1836
- raise RuntimeError(
1837
- f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
1838
- "This indicates token_ids_logprobs were not computed properly for the scoring request."
1839
- )
1840
-
1841
- for logprob, token_id, _ in output_logprobs[0]:
1842
- if token_id in label_token_ids:
1843
- logprobs[token_id] = logprob
1844
-
1845
- # Get scores in order of label_token_ids
1846
- score_list = [
1847
- logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
1848
- ]
1849
-
1850
- # Apply softmax to logprobs if needed
1851
- if apply_softmax:
1852
- score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
1853
- else:
1854
- # Convert logprobs to probabilities if not using softmax
1855
- score_list = [
1856
- math.exp(x) if x != float("-inf") else 0.0 for x in score_list
1857
- ]
1858
-
1859
- scores.append(score_list)
1860
-
1861
- return scores
2065
+ if use_multi_item_scoring:
2066
+ # Multi-item scoring: extract scores from input_token_ids_logprobs
2067
+ return self._process_multi_item_scoring_results(
2068
+ results, items, label_token_ids, apply_softmax, batch_request
2069
+ )
2070
+ else:
2071
+ # Single-item scoring: process each result separately
2072
+ return self._process_single_item_scoring_results(
2073
+ results, label_token_ids, apply_softmax
2074
+ )
1862
2075
 
1863
2076
  async def watch_load_thread(self):
1864
2077
  # Only for dp_controller when dp_size > 1
@@ -1874,6 +2087,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1874
2087
  load_udpate_req = WatchLoadUpdateReq(loads=loads)
1875
2088
  self.send_to_scheduler.send_pyobj(load_udpate_req)
1876
2089
 
2090
+ def _trace_request_start(
2091
+ self,
2092
+ obj: Union[GenerateReqInput, EmbeddingReqInput],
2093
+ created_time: Optional[float] = None,
2094
+ ):
2095
+ if obj.is_single:
2096
+ bootstrap_room = (
2097
+ obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
2098
+ )
2099
+ trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
2100
+ trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
2101
+ else:
2102
+ for i in range(len(obj.rid)):
2103
+ bootstrap_room = (
2104
+ obj.bootstrap_room[i]
2105
+ if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
2106
+ else None
2107
+ )
2108
+ trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
2109
+ trace_slice_start(
2110
+ "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
2111
+ )
2112
+
1877
2113
 
1878
2114
  class ServerStatus(Enum):
1879
2115
  Up = "Up"
@@ -1919,7 +2155,7 @@ class SignalHandler:
1919
2155
 
1920
2156
  def running_phase_sigquit_handler(self, signum=None, frame=None):
1921
2157
  logger.error(
1922
- "Received sigquit from a child process. It usually means the child failed."
2158
+ f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
1923
2159
  )
1924
2160
  self.tokenizer_manager.dump_requests_before_crash()
1925
2161
  kill_process_tree(os.getpid())