sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,6 @@
 import asyncio
 import copy
 import dataclasses
-import json
 import logging
 import math
 import os
@@ -25,7 +24,6 @@ import signal
 import sys
 import threading
 import time
-import uuid
 from collections import deque
 from contextlib import nullcontext
 from datetime import datetime
@@ -34,13 +32,13 @@ from http import HTTPStatus
 from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 import fastapi
+import orjson
 import torch
 import uvloop
 import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks
 
-from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.lora.lora_registry import LoRARegistry
@@ -60,7 +58,6 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     GetLoadReqInput,
     HealthCheckOutput,
-    MultiTokenizerWrapper,
     OpenSessionReqOutput,
     SessionParams,
     TokenizedEmbeddingReqInput,
@@ -90,10 +87,10 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
-    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.srt.utils.aio_rwlock import RWLock
 from sglang.srt.utils.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
@@ -157,7 +154,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.log_requests = server_args.log_requests
         self.log_requests_level = server_args.log_requests_level
         self.preferred_sampling_params = (
-            json.loads(server_args.preferred_sampling_params)
+            orjson.loads(server_args.preferred_sampling_params)
             if server_args.preferred_sampling_params
             else None
         )
@@ -173,7 +170,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
         self.max_req_input_len = None  # Will be set later in engine.py
-
         speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -183,6 +179,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             else server_args.speculative_num_draft_tokens
         )
 
+        # Initialize tokenizer and processor
         if self.model_config.is_multimodal:
             import_processors("sglang.srt.multimodal.processors")
             try:
@@ -223,6 +220,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             self.processor = _processor
             self.tokenizer = get_tokenizer_from_processor(self.processor)
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
+            self._initialize_multi_item_delimiter_text()
         else:
             self.mm_processor = self.processor = None
 
@@ -235,6 +233,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     trust_remote_code=server_args.trust_remote_code,
                     revision=server_args.revision,
                 )
+                self._initialize_multi_item_delimiter_text()
+
         # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
         if (
             server_args.enable_dynamic_batch_tokenizer
@@ -253,18 +253,23 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        if self.server_args.tokenizer_worker_num > 1:
-            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+        if self.server_args.tokenizer_worker_num == 1:
             self.send_to_scheduler = get_zmq_socket(
-                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
             )
         else:
-            self.send_to_scheduler = get_zmq_socket(
-                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            from sglang.srt.managers.multi_tokenizer_mixin import SenderWrapper
+
+            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+            send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
             )
 
+            # Make sure that each request carries the tokenizer_ipc_name for response routing
+            self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler)
+
         # Request states
-        self.no_create_loop = False
+        self._chosen_loop = None
         self.rid_to_state: Dict[str, ReqState] = {}
         self.asyncio_tasks = set()
@@ -273,6 +278,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
 
+        # Initial weights status
+        self.initial_weights_loaded = True
+        if server_args.checkpoint_engine_wait_weights_before_ready:
+            self.initial_weights_loaded = False
+
         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
@@ -304,6 +314,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()
 
+        # Disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -355,7 +366,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 (
                     FreezeGCReq,
                     lambda x: None,
-                ),  # For the case where the scheduler skips the detokenizer and forwards back to the tokenizer manager, we ignore it.
+                ),
+                # For the case where the scheduler skips the detokenizer and forwards back to the tokenizer manager, we ignore it.
                 (HealthCheckOutput, lambda x: None),
             ]
         )
@@ -372,13 +384,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         obj.normalize_batch_and_arguments()
 
         if self.server_args.tokenizer_worker_num > 1:
-            # Modify rid, add worker_id
-            if isinstance(obj.rid, list):
-                # If it's an array, add worker_id prefix to each element
-                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
-            else:
-                # If it's a single value, add worker_id prefix
-                obj.rid = f"{self.worker_id}_{obj.rid}"
+            self._attach_multi_http_worker_info(obj)
 
         if self.enable_trace:
             self._trace_request_start(obj, created_time)
@@ -582,9 +588,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         )
 
         if self.mm_processor and obj.contains_mm_input():
-            if not isinstance(obj.image_data, list):
+            if obj.image_data is not None and not isinstance(obj.image_data, list):
                 obj.image_data = [obj.image_data]
-            if not isinstance(obj.audio_data, list):
+            if obj.audio_data is not None and not isinstance(obj.audio_data, list):
                 obj.audio_data = [obj.audio_data]
             mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
                 image_data=obj.image_data,
@@ -724,6 +730,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             obj.token_ids_logprob,
             obj.stream,
             rid=obj.rid,
+            http_worker_ipc=obj.http_worker_ipc,
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
@@ -745,6 +752,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             sampling_params,
             rid=obj.rid,
             priority=obj.priority,
+            http_worker_ipc=obj.http_worker_ipc,
         )
 
         return tokenized_obj
@@ -755,6 +763,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         """Handle batch tokenization for text inputs only."""
         logger.debug(f"Starting batch tokenization for {batch_size} text requests")
 
+        # If the batch has no text, there is nothing to tokenize,
+        # so construct the return objects directly.
+        if not self._batch_has_text(batch_size, obj):
+            # All requests already have input_ids, no need to tokenize
+            return [await self._tokenize_one_request(obj[i]) for i in range(batch_size)]
+
+        self._validate_batch_tokenization_constraints(batch_size, obj)
+
         # Collect requests and texts
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]
@@ -804,6 +820,30 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
             )
 
+    def _batch_has_text(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> bool:
+        """Check if any request in the batch contains text or multimodal input."""
+        for i in range(batch_size):
+            if obj[i].text:
+                return True
+            elif self.is_generation and obj[i].contains_mm_input():
+                return True
+
+        return False
+
+    def _should_use_batch_tokenization(self, batch_size, requests) -> bool:
+        """Return True if we should run the tokenizer in batch mode.
+
+        Current policy:
+        - Respect the explicit server flag `enable_tokenizer_batch_encode`.
+        - Or, if no request has text or multimodal input (all use pre-tokenized input_ids or input_embeds), batch the requests without tokenization.
+        """
+        return batch_size > 0 and (
+            self.server_args.enable_tokenizer_batch_encode
+            or not self._batch_has_text(batch_size, requests)
+        )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -938,13 +978,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         generators = []
         rids = []
         if getattr(obj, "parallel_sample_num", 1) == 1:
-            if self.server_args.enable_tokenizer_batch_encode:
-                # Validate batch tokenization constraints
-                self._validate_batch_tokenization_constraints(batch_size, obj)
-
+            if self._should_use_batch_tokenization(batch_size, obj):
                 tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
-
-                # Send as a single batched request
                 self._send_batch_request(obj, tokenized_objs, created_time)
 
                 # Set up generators for each request in the batch
@@ -1078,8 +1113,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
-        if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1139,11 +1172,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         return background_tasks
 
     def auto_create_handle_loop(self):
-        if self.no_create_loop:
+        if self._chosen_loop is not None:
+            assert (
+                asyncio.get_event_loop() == self._chosen_loop
+            ), f"Please ensure only one event loop is ever used with SGLang. Previous loop: {self._chosen_loop}, current loop: {asyncio.get_event_loop()}"
             return
 
-        self.no_create_loop = True
         loop = asyncio.get_event_loop()
+        self._chosen_loop = loop
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.handle_loop))
         )
@@ -1315,12 +1351,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 )
                 continue
 
-            origin_rid = rid
-            if self.server_args.tokenizer_worker_num > 1:
-                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": origin_rid,
+                "id": rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1389,7 +1422,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             state.finished = recv_obj.finished_reasons[i] is not None
             if state.finished:
                 if self.server_args.speculative_algorithm:
-                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+                    self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
 
@@ -1537,6 +1570,43 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             ret.append(None)
         return ret
 
+    def _calculate_spec_decoding_metrics(
+        self,
+        meta_info: Dict[str, Any],
+        recv_obj: Union[
+            BatchStrOutput,
+            BatchEmbeddingOutput,
+            BatchMultimodalOutput,
+            BatchTokenIDOutput,
+        ],
+        i: int,
+    ) -> None:
+        """Calculate speculative decoding metrics, such as the acceptance rate and acceptance length."""
+        meta_info["spec_accept_rate"] = 0.0
+        meta_info["spec_accept_length"] = 0
+        meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
+        if (
+            recv_obj.spec_verify_ct[i] > 0
+            and self.server_args.speculative_num_steps is not None
+            and not isinstance(recv_obj, BatchEmbeddingOutput)
+            and hasattr(recv_obj, "spec_accepted_tokens")
+            # Checks that `spec_accepted_tokens[i]` will exist.
+            and len(recv_obj.spec_accepted_tokens) > i
+        ):
+            total_draft_tokens = (
+                recv_obj.spec_verify_ct[i] * self.server_args.speculative_num_steps
+            )
+            accepted_tokens = recv_obj.spec_accepted_tokens[i]
+
+            # Calculate per-request acceptance rate and average acceptance length.
+            if total_draft_tokens > 0:
+                # Calculate acceptance rate: accepted / (steps * lookahead)
+                meta_info["spec_accept_rate"] = accepted_tokens / total_draft_tokens
+                meta_info["spec_accept_length"] = (
+                    recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i]
+                )
+
     def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
         completion_tokens = (
             recv_obj.completion_tokens[i]
@@ -1637,9 +1707,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
-        origin_rid = recv_obj.rid
-        if self.server_args.tokenizer_worker_num > 1:
-            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1652,7 +1719,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": origin_rid,
+                    "id": recv_obj.rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -1678,6 +1745,201 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if len(self.model_update_tmp) == self.server_args.dp_size:
                 self.model_update_result.set_result(self.model_update_tmp)
 
+    def _initialize_multi_item_delimiter_text(self):
+        """Initialize the multi-item delimiter text from its token ID after the tokenizer is loaded."""
+        if (
+            hasattr(self.server_args, "multi_item_scoring_delimiter")
+            and self.server_args.multi_item_scoring_delimiter is not None
+            and self.tokenizer is not None
+        ):
+            try:
+                self.multi_item_delimiter_text = self.tokenizer.decode(
+                    [self.server_args.multi_item_scoring_delimiter],
+                    skip_special_tokens=False,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
+                )
+                self.multi_item_delimiter_text = None
+
+    def _build_multi_item_token_sequence(
+        self, query: List[int], items: List[List[int]], delimiter_token_id: int
+    ) -> List[int]:
+        """
+        Build a single token sequence for multi-item scoring.
+        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Args:
+            query: Query token IDs
+            items: List of item token ID sequences
+            delimiter_token_id: Token ID to use as delimiter
+
+        Returns:
+            Combined token sequence
+        """
+        combined_sequence = query[:]  # Start with query
+
+        for item in items:
+            combined_sequence.append(delimiter_token_id)  # Add delimiter
+            combined_sequence.extend(item)  # Add item tokens
+
+        # Add a final delimiter after the last item for logprob extraction
+        combined_sequence.append(delimiter_token_id)
+
+        return combined_sequence
+
+    def _extract_logprobs_for_tokens(
+        self, logprobs_data: List, label_token_ids: List[int]
+    ) -> Dict[int, float]:
+        """
+        Extract logprobs for specified token IDs from logprobs data.
+
+        Args:
+            logprobs_data: List of (logprob, token_id, text) tuples
+            label_token_ids: Token IDs to extract logprobs for
+
+        Returns:
+            Dictionary mapping token_id to logprob
+        """
+        logprobs = {}
+        if logprobs_data:
+            for logprob, token_id, _ in logprobs_data:
+                if token_id in label_token_ids:
+                    logprobs[token_id] = logprob
+        return logprobs
+
+    def _convert_logprobs_to_scores(
+        self,
+        logprobs: Dict[int, float],
+        label_token_ids: List[int],
+        apply_softmax: bool,
+    ) -> List[float]:
+        """
+        Convert a logprobs dictionary to an ordered score list.
+
+        Args:
+            logprobs: Dictionary mapping token_id to logprob
+            label_token_ids: Token IDs in the desired order
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of scores in the same order as label_token_ids
+        """
+        score_list = [
+            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
+        ]
+
+        if apply_softmax:
+            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
+        else:
+            # Convert logprobs to probabilities if not using softmax
+            score_list = [
+                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
+            ]
+
+        return score_list
+
+    def _process_multi_item_scoring_results(
+        self,
+        results: Any,
+        items: List,
+        label_token_ids: List[int],
+        apply_softmax: bool,
+        batch_request=None,
+    ) -> List[List[float]]:
+        """
+        Process results from a multi-item scoring request.
+        Extracts logprobs at delimiter positions from input_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            items: List of items being scored
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+            batch_request: The original batch request containing the input sequence
+
+        Returns:
+            List of score lists, one for each item
+        """
+        single_result = results[0] if isinstance(results, list) else results
+
+        # For multi-item scoring, logprobs are in input_token_ids_logprobs
+        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
+
+        if not input_logprobs:
+            raise RuntimeError(
+                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for multi-item scoring."
+            )
+
+        scores = []
+        num_items = len(items) if isinstance(items, list) else 1
+
+        # Check if we have the expected number of logprobs
+        expected_logprobs_count = num_items + 1
+        if len(input_logprobs) != expected_logprobs_count:
+            raise RuntimeError(
+                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
+                f"with {num_items} items, but got {len(input_logprobs)}. "
+                f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
+            )
+
+        # Skip the first delimiter and process the remaining delimiter positions.
+        # The first one is excluded because it marks the boundary between the query and the first item, not an item boundary.
+        start_idx = 1 if len(input_logprobs) > 1 else 0
+
+        # Process logprobs for each item position (excluding the first delimiter)
+        for item_idx in range(num_items):
+            logprob_idx = start_idx + item_idx
+            item_logprobs_data = input_logprobs[logprob_idx]
+            logprobs = self._extract_logprobs_for_tokens(
+                item_logprobs_data, label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
+    def _process_single_item_scoring_results(
+        self, results: Any, label_token_ids: List[int], apply_softmax: bool
+    ) -> List[List[float]]:
+        """
+        Process results from a single-item scoring request.
+        Single-item scoring results are stored in output_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of score lists, one for each result
+        """
+        scores = []
+
+        for result in results:
+            # For single-item scoring, logprobs are in output_token_ids_logprobs
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            if not output_logprobs or len(output_logprobs) == 0:
+                raise RuntimeError(
+                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
+                )
+
+            # Extract logprobs for the first (and only) position
+            logprobs = self._extract_logprobs_for_tokens(
+                output_logprobs[0], label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
     async def score_request(
         self,
         query: Optional[Union[str, List[int]]] = None,
@@ -1688,7 +1950,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         request: Optional[Any] = None,
     ) -> List[List[float]]:
         """
-        See Engine.score() for more details.
+        Score the probability of specified token IDs appearing after the given (query + item) pair.
+
+        This method supports two scoring approaches:
+        1. Single-item scoring (default): process each query+item pair independently.
+        2. Multi-item scoring: when multi_item_scoring_delimiter is set, combine the query and
+           multiple items into a single sequence using the delimiter for efficient processing.
+           Note: the item_first parameter is ignored in multi-item scoring mode since it uses
+           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Multi-item scoring works with both text and pre-tokenized inputs:
+        - Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
+        - Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+
+        Args:
+            query: The query text or pre-tokenized query token IDs
+            items: The item text(s) or pre-tokenized item token IDs
+            label_token_ids: List of token IDs to compute probabilities for
+            apply_softmax: Whether to normalize probabilities using softmax
+            item_first: If True, prepend items to the query. Ignored for multi-item scoring.
+            request: Optional FastAPI request object
+
+        Returns:
+            List of lists containing probabilities for each item and each label token
         """
         if label_token_ids is None:
             raise ValueError("label_token_ids must be provided")
@@ -1701,9 +1985,17 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                         f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                     )
 
+        # Check if multi-item scoring is enabled by the presence of a delimiter
+        use_multi_item_scoring = (
+            self.server_args.multi_item_scoring_delimiter is not None
+            and self.multi_item_delimiter_text is not None
+        )
+
         batch_request = GenerateReqInput(
             token_ids_logprob=label_token_ids,
             return_logprob=True,
+            # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
+            logprob_start_len=0 if use_multi_item_scoring else -1,
             stream=False,
            sampling_params={"max_new_tokens": 0},
         )
@@ -1715,12 +2007,23 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         ):
             # Both query and items are text
             items_list = [items] if isinstance(items, str) else items
-            if item_first:
-                prompts = [f"{item}{query}" for item in items_list]
-            else:
-                prompts = [f"{query}{item}" for item in items_list]
 
-            batch_request.text = prompts
+            if use_multi_item_scoring:
+                # Multi-item scoring: create a single prompt with the delimiter text.
+                # Always use the format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+                # (item_first is ignored for multi-item scoring)
+                delimiter = self.multi_item_delimiter_text
+                combined_items = delimiter.join(items_list)
+                # Add a final delimiter after the last item for logprob extraction
+                single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
+                batch_request.text = [single_prompt]
+            else:
+                # Single-item scoring: create separate prompts for each item
+                if item_first:
+                    prompts = [f"{item}{query}" for item in items_list]
+                else:
+                    prompts = [f"{query}{item}" for item in items_list]
+                batch_request.text = prompts
 
         elif (
             isinstance(query, list)
@@ -1729,61 +2032,38 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             and isinstance(items[0], list)
         ):
             # Both query and items are token IDs
-            if item_first:
-                input_ids_list = [item + query for item in items]
+            if use_multi_item_scoring:
+                # Multi-item scoring: concatenate with the delimiter token ID.
+                # Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+                delimiter_token_id = self.server_args.multi_item_scoring_delimiter
+                combined_input_ids = self._build_multi_item_token_sequence(
+                    query, items, delimiter_token_id
+                )
+                batch_request.input_ids = [combined_input_ids]
             else:
-                input_ids_list = [query + item for item in items]
-
-            batch_request.input_ids = input_ids_list
+                # Single-item scoring: process each item separately
+                if item_first:
+                    input_ids_list = [item + query for item in items]
+                else:
+                    input_ids_list = [query + item for item in items]
+                batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
             )
 
         results = await self.generate_request(batch_request, request).__anext__()
-        scores = []
 
-        for result in results:
-            # Get logprobs for each token
-            logprobs = {}
-
-            # For scoring requests, we read from output_token_ids_logprobs since we want
-            # the logprobs for specific tokens mentioned in the label_token_ids at
-            # the next position after the last token in the prompt
-            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
-
-            # Check if output_logprobs is properly populated
-            if (
-                output_logprobs is None
-                or not output_logprobs
-                or len(output_logprobs) == 0
-            ):
-                raise RuntimeError(
-                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
-                    "This indicates token_ids_logprobs were not computed properly for the scoring request."
-                )
-
-            for logprob, token_id, _ in output_logprobs[0]:
-                if token_id in label_token_ids:
-                    logprobs[token_id] = logprob
-
-            # Get scores in order of label_token_ids
-            score_list = [
-                logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
-            ]
-
-            # Apply softmax to logprobs if needed
-            if apply_softmax:
-                score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
-            else:
-                # Convert logprobs to probabilities if not using softmax
-                score_list = [
-                    math.exp(x) if x != float("-inf") else 0.0 for x in score_list
-                ]
-
-            scores.append(score_list)
-
-        return scores
+        if use_multi_item_scoring:
+            # Multi-item scoring: extract scores from input_token_ids_logprobs
+            return self._process_multi_item_scoring_results(
+                results, items, label_token_ids, apply_softmax, batch_request
+            )
+        else:
+            # Single-item scoring: process each result separately
+            return self._process_single_item_scoring_results(
+                results, label_token_ids, apply_softmax
+            )
 
     async def watch_load_thread(self):
         # Only for dp_controller when dp_size > 1
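
A note on the multi-item scoring change above: per the new `score_request` docstring, when `multi_item_scoring_delimiter` is set, the query and all items are packed into one sequence, a single forward pass is run with `max_new_tokens=0` and `logprob_start_len=0`, and per-item scores are read from `input_token_ids_logprobs` at the delimiter positions (skipping the first delimiter, which only marks the query/item boundary). The minimal sketch below mirrors the packing logic of `_build_multi_item_token_sequence` from this diff; the token IDs are made up for illustration.

    from typing import List

    def build_multi_item_token_sequence(
        query: List[int], items: List[List[int]], delimiter_token_id: int
    ) -> List[int]:
        # Pack as: query <delim> item1 <delim> item2 ... itemN <delim>
        seq = query[:]
        for item in items:
            seq.append(delimiter_token_id)  # boundary before each item
            seq.extend(item)
        seq.append(delimiter_token_id)  # final delimiter, used to score the last item
        return seq

    # Hypothetical token IDs, for illustration only.
    seq = build_multi_item_token_sequence(
        query=[101, 102], items=[[201], [301, 302]], delimiter_token_id=0
    )
    assert seq == [101, 102, 0, 201, 0, 301, 302, 0]
    # The sequence contains len(items) + 1 delimiters; the logprobs gathered at
    # delimiters 2..N+1 (skipping the first) yield the per-item label-token scores.

Relatedly, the new `_calculate_spec_decoding_metrics` helper reports per-request speculative decoding quality as spec_accept_rate = spec_accepted_tokens / (spec_verify_ct * speculative_num_steps) and spec_accept_length = completion_tokens / spec_verify_ct, replacing the bare spec_verify_ct counter previously attached to meta_info.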