sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
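
The remainder of the page is the per-file diff; the hunks below come from sglang/srt/managers/tokenizer_manager.py (entry 240 in the list above). As a minimal sketch, a comparison like this can be reproduced locally with the standard library, assuming both wheels have already been downloaded (for example with `pip download sglang==0.5.2rc2` and `pip download sglang==0.5.3.post1`); the local file names below are assumptions based on the usual PyPI naming, and only the version numbers are taken from the title of this page.

import difflib
import zipfile

# Assumed local wheel paths (not guaranteed by this page).
OLD_WHEEL = "sglang-0.5.2rc2-py3-none-any.whl"
NEW_WHEEL = "sglang-0.5.3.post1-py3-none-any.whl"
MEMBER = "sglang/srt/managers/tokenizer_manager.py"


def read_member(wheel_path: str, member: str) -> list[str]:
    # A wheel is a zip archive, so a packaged module can be read directly.
    with zipfile.ZipFile(wheel_path) as wheel:
        return wheel.read(member).decode("utf-8").splitlines(keepends=True)


diff = difflib.unified_diff(
    read_member(OLD_WHEEL, MEMBER),
    read_member(NEW_WHEEL, MEMBER),
    fromfile=f"{OLD_WHEEL}/{MEMBER}",
    tofile=f"{NEW_WHEEL}/{MEMBER}",
)
print("".join(diff))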
sglang/srt/managers/tokenizer_manager.py
@@ -31,18 +31,7 @@ from contextlib import nullcontext
 from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    Any,
-    Awaitable,
-    Deque,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 import fastapi
 import torch
@@ -53,80 +42,49 @@ from fastapi import BackgroundTasks
 
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    KVClassType,
-    TransferBackend,
-    get_kv_class,
-)
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
-from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.lora.lora_registry import LoRARegistry
+from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
-    ClearHiCacheReqInput,
-    ClearHiCacheReqOutput,
-    CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
-    ExpertDistributionReq,
-    ExpertDistributionReqOutput,
-    FlushCacheReqInput,
-    FlushCacheReqOutput,
     FreezeGCReq,
     GenerateReqInput,
-    GetInternalStateReq,
-    GetInternalStateReqOutput,
-    GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
+    GetLoadReqInput,
     HealthCheckOutput,
-    InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
-    LoadLoRAAdapterReqInput,
-    LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
-    MultiTokenizerWarpper,
-    OpenSessionReqInput,
+    MultiTokenizerWrapper,
     OpenSessionReqOutput,
-    ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
-    ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
-    ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
-    SlowDownReqInput,
-    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UnloadLoRAAdapterReqInput,
-    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
-    UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
-    UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.tracing.trace import (
+    trace_get_proc_propagate_context,
+    trace_req_finish,
+    trace_req_start,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.utils import (
     configure_gc_warning,
     dataclass_to_string_truncated,
@@ -136,6 +94,11 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -180,7 +143,7 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
-class TokenizerManager:
+class TokenizerManager(TokenizerCommunicatorMixin):
     """TokenizerManager is a process that tokenizes the text."""
 
     def __init__(
@@ -199,6 +162,7 @@ class TokenizerManager:
             else None
         )
         self.crash_dump_folder = server_args.crash_dump_folder
+        self.enable_trace = server_args.enable_trace
 
         # Read model args
        self.model_path = server_args.model_path
@@ -210,8 +174,19 @@ class TokenizerManager:
         self.image_token_id = self.model_config.image_token_id
         self.max_req_input_len = None  # Will be set later in engine.py
 
+        speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.reserve_input_token_num = (
+            0
+            if speculative_algorithm.is_none()
+            else server_args.speculative_num_draft_tokens
+        )
+        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
+        self.multi_item_delimiter_text = None
+
         if self.model_config.is_multimodal:
-            import_processors()
+            import_processors("sglang.srt.multimodal.processors")
             try:
                 _processor = get_processor(
                     server_args.tokenizer_path,
@@ -250,6 +225,7 @@ class TokenizerManager:
             self.processor = _processor
             self.tokenizer = get_tokenizer_from_processor(self.processor)
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
+            self._initialize_multi_item_delimiter_text()
         else:
             self.mm_processor = self.processor = None
 
@@ -262,6 +238,19 @@ class TokenizerManager:
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
+            self._initialize_multi_item_delimiter_text()
+        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
+        if (
+            server_args.enable_dynamic_batch_tokenizer
+            and not server_args.skip_tokenizer_init
+        ):
+            self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
+                self.tokenizer,
+                max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
+                batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
+            )
+        else:
+            self.async_dynamic_batch_tokenizer = None
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -319,8 +308,10 @@ class TokenizerManager:
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()
 
-        # For PD disaggregtion
-        self.init_disaggregation()
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.bootstrap_server = start_disagg_service(self.server_args)
 
         # For load balancing
         self.current_load = 0
@@ -328,12 +319,16 @@ class TokenizerManager:
 
         # Metrics
         if self.enable_metrics:
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                # TODO: Add lora name/path in the future,
+            }
+            if server_args.tokenizer_metrics_allowed_custom_labels:
+                for label in server_args.tokenizer_metrics_allowed_custom_labels:
+                    labels[label] = ""
             self.metrics_collector = TokenizerMetricsCollector(
                 server_args=server_args,
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
+                labels=labels,
                 bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                 bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                 bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
@@ -344,58 +339,14 @@ class TokenizerManager:
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)
 
-        # Communicators
-        self.init_weights_update_group_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_distributed_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_tensor_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_weights_by_name_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.release_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.resume_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.slow_down_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.flush_cache_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.clear_hicache_storage_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.profile_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.expert_distribution_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_lora_adapter_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
                     (
-                        BatchStrOut,
-                        BatchEmbeddingOut,
-                        BatchTokenIDOut,
-                        BatchMultimodalOut,
+                        BatchStrOutput,
+                        BatchEmbeddingOutput,
+                        BatchTokenIDOutput,
+                        BatchMultimodalOutput,
                     ),
                     self._handle_batch_output,
                 ),
@@ -405,100 +356,15 @@ class TokenizerManager:
                 (
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    InitWeightsUpdateGroupReqOutput,
-                    self.init_weights_update_group_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromDistributedReqOutput,
-                    self.update_weights_from_distributed_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromTensorReqOutput,
-                    self.update_weights_from_tensor_communicator.handle_recv,
-                ),
-                (
-                    GetWeightsByNameReqOutput,
-                    self.get_weights_by_name_communicator.handle_recv,
-                ),
-                (
-                    ReleaseMemoryOccupationReqOutput,
-                    self.release_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    ResumeMemoryOccupationReqOutput,
-                    self.resume_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    SlowDownReqOutput,
-                    self.slow_down_communicator.handle_recv,
-                ),
-                (
-                    ClearHiCacheReqOutput,
-                    self.clear_hicache_storage_communicator.handle_recv,
-                ),
-                (
-                    FlushCacheReqOutput,
-                    self.flush_cache_communicator.handle_recv,
-                ),
-                (
-                    ProfileReqOutput,
-                    self.profile_communicator.handle_recv,
-                ),
                 (
                     FreezeGCReq,
                     lambda x: None,
                 ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
-                (
-                    GetInternalStateReqOutput,
-                    self.get_internal_state_communicator.handle_recv,
-                ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
-                (
-                    ExpertDistributionReqOutput,
-                    self.expert_distribution_communicator.handle_recv,
-                ),
-                (
-                    LoRAUpdateResult,
-                    self.update_lora_adapter_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
 
-    def init_disaggregation(self):
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_communicators(server_args)
 
     async def generate_request(
         self,
@@ -518,6 +384,9 @@ class TokenizerManager:
             # If it's a single value, add worker_id prefix
             obj.rid = f"{self.worker_id}_{obj.rid}"
 
+        if self.enable_trace:
+            self._trace_request_start(obj, created_time)
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -543,6 +412,144 @@ class TokenizerManager:
         ):
             yield response
 
+    def _detect_input_format(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool
+    ) -> str:
+        """Detect the format of input texts for proper tokenization handling.
+
+        Returns:
+            - "single_string": Regular single text like "Hello world"
+            - "batch_strings": Regular batch like ["Hello", "World"]
+            - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
+        """
+        if isinstance(texts, str):
+            return "single_string"
+
+        if (
+            is_cross_encoder
+            and len(texts) > 0
+            and isinstance(texts[0], list)
+            and len(texts[0]) == 2
+        ):
+            return "cross_encoder_pairs"
+
+        return "batch_strings"
+
+    def _prepare_tokenizer_input(
+        self, texts: Union[str, List[str]], input_format: str
+    ) -> Union[List[str], List[List[str]]]:
+        """Prepare input for the tokenizer based on detected format."""
+        if input_format == "single_string":
+            return [texts]  # Wrap single string for batch processing
+        elif input_format == "cross_encoder_pairs":
+            return texts  # Already in correct format: [["query", "doc"]]
+        else:  # batch_strings
+            return texts  # Already in correct format: ["text1", "text2"]
+
+    def _extract_tokenizer_results(
+        self,
+        input_ids: List[List[int]],
+        token_type_ids: Optional[List[List[int]]],
+        input_format: str,
+        original_batch_size: int,
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """Extract results from tokenizer output based on input format."""
+
+        # For single inputs (string or single cross-encoder pair), extract first element
+        if (
+            input_format in ["single_string", "cross_encoder_pairs"]
+            and original_batch_size == 1
+        ):
+            single_input_ids = input_ids[0] if input_ids else []
+            single_token_type_ids = token_type_ids[0] if token_type_ids else None
+            return single_input_ids, single_token_type_ids
+
+        # For true batches, return as-is
+        return input_ids, token_type_ids
+
+    async def _tokenize_texts(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool = False
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """
+        Tokenize text(s) using the appropriate tokenizer strategy.
+
+        This method handles multiple input formats and chooses between async dynamic
+        batch tokenizer (for single texts only) and regular tokenizer.
+
+        Args:
+            texts: Text input in various formats:
+
+                Regular cases:
+                - Single string: "How are you?"
+                - Batch of strings: ["Hello", "World", "How are you?"]
+
+                Cross-encoder cases (sentence pairs for similarity/ranking):
+                - Single pair: [["query text", "document text"]]
+                - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
+
+            is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
+                Enables proper handling of sentence pairs with segment IDs.
+
+        Returns:
+            Single input cases:
+                Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
+                Example: ([101, 2129, 102], [0, 0, 0]) for single text
+                Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
+
+            Batch input cases:
+                Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
+                Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
+
+            Note: token_type_ids is None unless is_cross_encoder=True.
+        """
+        if not texts or self.tokenizer is None:
+            raise ValueError("texts cannot be empty and tokenizer must be initialized")
+
+        # Step 1: Detect input format and prepare for tokenization
+        input_format = self._detect_input_format(texts, is_cross_encoder)
+        tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
+        original_batch_size = len(texts) if not isinstance(texts, str) else 1
+
+        # Step 2: Set up tokenizer arguments
+        tokenizer_kwargs = (
+            {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
+        )
+
+        # Step 3: Choose tokenization strategy
+        use_async_tokenizer = (
+            self.async_dynamic_batch_tokenizer is not None
+            and input_format == "single_string"
+        )
+
+        if use_async_tokenizer:
+            logger.debug("Using async dynamic batch tokenizer for single text")
+            result = await self.async_dynamic_batch_tokenizer.encode(
+                tokenizer_input[0], **tokenizer_kwargs
+            )
+            # Convert to batch format for consistency
+            input_ids = [result["input_ids"]]
+            token_type_ids = (
+                [result["token_type_ids"]]
+                if is_cross_encoder and result.get("token_type_ids")
+                else None
+            )
+        else:
+            logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
+            encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
+            input_ids = encoded["input_ids"]
+            token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
+
+        # Step 4: Extract results based on input format
+        return self._extract_tokenizer_results(
+            input_ids, token_type_ids, input_format, original_batch_size
+        )
+
     async def _tokenize_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -573,14 +580,10 @@ class TokenizerManager:
                     "accept text prompts. Please provide input_ids or re-initialize "
                     "the engine with skip_tokenizer_init=False."
                 )
-            encoded = self.tokenizer(
-                input_text, return_token_type_ids=is_cross_encoder_request
-            )
 
-            input_ids = encoded["input_ids"]
-            if is_cross_encoder_request:
-                input_ids = encoded["input_ids"][0]
-                token_type_ids = encoded.get("token_type_ids", [None])[0]
+            input_ids, token_type_ids = await self._tokenize_texts(
+                input_text, is_cross_encoder_request
+            )
 
         if self.mm_processor and obj.contains_mm_input():
             if not isinstance(obj.image_data, list):
@@ -600,6 +603,7 @@ class TokenizerManager:
             mm_inputs = None
 
         self._validate_one_request(obj, input_ids)
+        trace_slice_end("tokenize", obj.rid)
        return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
@@ -612,6 +616,7 @@ class TokenizerManager:
             _max_req_len = self.context_len
 
         input_token_num = len(input_ids) if input_ids is not None else 0
+        input_token_num += self.reserve_input_token_num
         if input_token_num >= self.context_len:
             if self.server_args.allow_auto_truncate:
                 logger.warning(
@@ -674,7 +679,7 @@ class TokenizerManager:
         ):
             raise ValueError(
                 "The server is not configured to enable custom logit processor. "
-                "Please set `--enable-custom-logits-processor` to enable this feature."
+                "Please set `--enable-custom-logit-processor` to enable this feature."
             )
 
     def _validate_input_ids_in_vocab(
@@ -713,7 +718,6 @@ class TokenizerManager:
             )
 
             tokenized_obj = TokenizedGenerateReqInput(
-                obj.rid,
                 input_text,
                 input_ids,
                 mm_inputs,
@@ -723,6 +727,7 @@ class TokenizerManager:
                 obj.top_logprobs_num,
                 obj.token_ids_logprob,
                 obj.stream,
+                rid=obj.rid,
                 bootstrap_host=obj.bootstrap_host,
                 bootstrap_port=obj.bootstrap_port,
                 bootstrap_room=obj.bootstrap_room,
@@ -732,15 +737,18 @@ class TokenizerManager:
                 custom_logit_processor=obj.custom_logit_processor,
                 return_hidden_states=obj.return_hidden_states,
                 data_parallel_rank=obj.data_parallel_rank,
+                priority=obj.priority,
+                extra_key=obj.extra_key,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                obj.rid,
                 input_text,
                 input_ids,
                 mm_inputs,
                 token_type_ids,
                 sampling_params,
+                rid=obj.rid,
+                priority=obj.priority,
             )
 
         return tokenized_obj
@@ -755,19 +763,30 @@ class TokenizerManager:
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]
 
-        # Batch tokenize all texts
-        encoded = self.tokenizer(texts)
-        input_ids_list = encoded["input_ids"]
+        # Check if any request is a cross-encoder request
+        is_cross_encoder_request = any(
+            isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
+            for req in requests
+        )
+
+        # Batch tokenize all texts using unified method
+        input_ids_list, token_type_ids_list = await self._tokenize_texts(
+            texts, is_cross_encoder_request
+        )
 
         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
             self._validate_one_request(obj[i], input_ids_list[i])
+            token_type_ids = (
+                token_type_ids_list[i] if token_type_ids_list is not None else None
+            )
             tokenized_objs.append(
                 self._create_tokenized_object(
-                    req, req.text, input_ids_list[i], None, None
+                    req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
+            trace_slice_end("tokenize", req.rid)
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs
 
@@ -795,9 +814,12 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        trace_slice_start("dispatch", obj.rid)
+        tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
         return state
 
     def _send_batch_request(
@@ -1015,73 +1037,16 @@ class TokenizerManager:
             except StopAsyncIteration:
                 pass
 
-    async def flush_cache(self) -> FlushCacheReqOutput:
-        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
-
-    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
-        """Clear the hierarchical cache storage."""
-        # Delegate to the scheduler to handle HiCacheStorage clearing
-        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
-            0
-        ]
-
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
-        req = AbortReq(rid, abort_all)
+        req = AbortReq(rid=rid, abort_all=abort_all)
         self.send_to_scheduler.send_pyobj(req)
-
         if self.enable_metrics:
-            self.metrics_collector.observe_one_aborted_request()
-
-    async def start_profile(
-        self,
-        output_dir: Optional[str] = None,
-        start_step: Optional[int] = None,
-        num_steps: Optional[int] = None,
-        activities: Optional[List[str]] = None,
-        with_stack: Optional[bool] = None,
-        record_shapes: Optional[bool] = None,
-        profile_by_stage: bool = False,
-    ):
-        self.auto_create_handle_loop()
-        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
-        with_stack = False if with_stack is False or env_with_stack is False else True
-        req = ProfileReq(
-            type=ProfileReqType.START_PROFILE,
-            output_dir=output_dir,
-            start_step=start_step,
-            num_steps=num_steps,
-            activities=activities,
-            with_stack=with_stack,
-            record_shapes=record_shapes,
-            profile_by_stage=profile_by_stage,
-            profile_id=str(time.time()),
-        )
-        return await self._execute_profile(req)
-
-    async def stop_profile(self):
-        self.auto_create_handle_loop()
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        return await self._execute_profile(req)
-
-    async def _execute_profile(self, req: ProfileReq):
-        result = (await self.profile_communicator(req))[0]
-        if not result.success:
-            raise RuntimeError(result.message)
-        return result
-
-    async def start_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
-
-    async def stop_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
-
-    async def dump_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
+            # TODO: also use custom_labels from the request
+            self.metrics_collector.observe_one_aborted_request(
+                self.metrics_collector.labels
+            )
 
     async def pause_generation(self):
         async with self.is_pause_cond:
@@ -1118,7 +1083,7 @@ class TokenizerManager:
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
         if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1143,291 +1108,6 @@ class TokenizerManager:
         all_paused_requests = [r.num_paused_requests for r in result]
         return all_success, all_message, all_paused_requests
 
-    async def init_weights_update_group(
-        self,
-        obj: InitWeightsUpdateGroupReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
-
-    async def update_weights_from_distributed(
-        self,
-        obj: UpdateWeightsFromDistributedReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
-
-    async def update_weights_from_tensor(
-        self,
-        obj: UpdateWeightsFromTensorReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_tensor_communicator(obj))[0]
-            return result.success, result.message
-
-    async def load_lora_adapter(
-        self,
-        obj: LoadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> LoadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start load Lora adapter. Lora name=%s, path=%s",
-                obj.lora_name,
-                obj.lora_path,
-            )
-
-            async with self.lora_update_lock:
-                if (
-                    self.server_args.max_loaded_loras is not None
-                    and self.lora_registry.num_registered_loras
-                    >= self.server_args.max_loaded_loras
-                ):
-                    raise ValueError(
-                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
-                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
-                        "Please unload some LoRA adapters before loading new ones."
-                    )
-
-                # Generate new uniquely identifiable LoRARef object.
-                new_adapter = LoRARef(
-                    lora_name=obj.lora_name,
-                    lora_path=obj.lora_path,
-                    pinned=obj.pinned,
-                )
-
-                # Trigger the actual loading operation at the backend processes.
-                obj.lora_id = new_adapter.lora_id
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                # Register the LoRA adapter only after loading is successful.
-                if result.success:
-                    await self.lora_registry.register(new_adapter)
-
-                return result
-        except ValueError as e:
-            return LoadLoRAAdapterReqOutput(
-                success=False,
-                error_message=str(e),
-            )
-
-    async def unload_lora_adapter(
-        self,
-        obj: UnloadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> UnloadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            assert (
-                obj.lora_name is not None
-            ), "lora_name must be provided to unload LoRA adapter"
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start unload Lora adapter. Lora name=%s",
-                obj.lora_name,
-            )
-
-            async with self.lora_update_lock:
-                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-                # from being started.
-                lora_id = await self.lora_registry.unregister(obj.lora_name)
-                obj.lora_id = lora_id
-
-                # Initiate the actual unloading operation at the backend processes only after all
-                # ongoing requests using this LoRA adapter are finished.
-                await self.lora_registry.wait_for_unload(lora_id)
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                return result
-        except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
-    async def open_session(
-        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-
-        if obj.session_id is None:
-            obj.session_id = uuid.uuid4().hex
-        elif obj.session_id in self.session_futures:
-            return None
-
-        if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
-        self.send_to_scheduler.send_pyobj(obj)
-
-        self.session_futures[obj.session_id] = asyncio.Future()
-        session_id = await self.session_futures[obj.session_id]
-        del self.session_futures[obj.session_id]
-        return session_id
-
-    async def close_session(
-        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
-    ):
-        await self.send_to_scheduler.send_pyobj(obj)
-
-    async def get_internal_state(self) -> List[Dict[Any, Any]]:
-        req = GetInternalStateReq()
-        responses: List[GetInternalStateReqOutput] = (
-            await self.get_internal_state_communicator(req)
-        )
-        # Many DP ranks
-        return [res.internal_state for res in responses]
-
-    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
-        responses: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return [res.updated for res in responses]
-
-    async def get_load(self) -> dict:
1370
- # TODO(lsyin): fake load report server
1371
- if not self.current_load_lock.locked():
1372
- async with self.current_load_lock:
1373
- internal_state = await self.get_internal_state()
1374
- self.current_load = internal_state[0]["load"]
1375
- return {"load": self.current_load}
1376
-
1377
- def get_log_request_metadata(self):
1378
- max_length = None
1379
- skip_names = None
1380
- out_skip_names = None
1381
- if self.log_requests:
1382
- if self.log_requests_level == 0:
1383
- max_length = 1 << 30
1384
- skip_names = set(
1385
- [
1386
- "text",
1387
- "input_ids",
1388
- "input_embeds",
1389
- "image_data",
1390
- "audio_data",
1391
- "lora_path",
1392
- "sampling_params",
1393
- ]
1394
- )
1395
- out_skip_names = set(
1396
- [
1397
- "text",
1398
- "output_ids",
1399
- "embedding",
1400
- ]
1401
- )
1402
- elif self.log_requests_level == 1:
1403
- max_length = 1 << 30
1404
- skip_names = set(
1405
- [
1406
- "text",
1407
- "input_ids",
1408
- "input_embeds",
1409
- "image_data",
1410
- "audio_data",
1411
- "lora_path",
1412
- ]
1413
- )
1414
- out_skip_names = set(
1415
- [
1416
- "text",
1417
- "output_ids",
1418
- "embedding",
1419
- ]
1420
- )
1421
- elif self.log_requests_level == 2:
1422
- max_length = 2048
1423
- elif self.log_requests_level == 3:
1424
- max_length = 1 << 30
1425
- else:
1426
- raise ValueError(
1427
- f"Invalid --log-requests-level: {self.log_requests_level=}"
1428
- )
1429
- return max_length, skip_names, out_skip_names
1430
-
1431
1111
  def configure_logging(self, obj: ConfigureLoggingReq):
1432
1112
  if obj.log_requests is not None:
1433
1113
  self.log_requests = obj.log_requests
@@ -1492,6 +1172,9 @@ class TokenizerManager:
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )
 
     def dump_requests_before_crash(self):
         if self.crash_dump_performed:
@@ -1583,12 +1266,12 @@ class TokenizerManager:
         # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
+            remaining_rids = list(self.rid_to_state.keys())
 
             if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
-                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
-                    remain_num_req,
+                    "Signal SIGTERM received while health check failed. Force exiting."
                 )
                 self.dump_requests_before_crash()
                 break
@@ -1596,13 +1279,12 @@ class TokenizerManager:
             elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                 # if force shutdown flag set, exit immediately
                 logger.error(
-                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
-                    remain_num_req,
+                    "Signal SIGTERM received while force shutdown flag set. Force exiting."
                 )
                 break
 
             logger.info(
-                f"Gracefully exiting... remaining number of requests {remain_num_req}"
+                f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
             )
             if remain_num_req > 0:
                 await asyncio.sleep(5)
@@ -1623,7 +1305,10 @@ class TokenizerManager:
     def _handle_batch_output(
         self,
         recv_obj: Union[
-            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
+            BatchStrOutput,
+            BatchEmbeddingOutput,
+            BatchMultimodalOutput,
+            BatchTokenIDOutput,
         ],
     ):
         for i, rid in enumerate(recv_obj.rids):
@@ -1657,7 +1342,7 @@ class TokenizerManager:
                     i,
                 )
 
-            if not isinstance(recv_obj, BatchEmbeddingOut):
+            if not isinstance(recv_obj, BatchEmbeddingOutput):
                 meta_info.update(
                     {
                         "completion_tokens": recv_obj.completion_tokens[i],
@@ -1668,7 +1353,7 @@ class TokenizerManager:
             if getattr(recv_obj, "output_hidden_states", None):
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
 
-            if isinstance(recv_obj, BatchStrOut):
+            if isinstance(recv_obj, BatchStrOutput):
                 state.text += recv_obj.output_strs[i]
                 if state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
@@ -1683,7 +1368,7 @@ class TokenizerManager:
                         "output_ids": output_token_ids,
                         "meta_info": meta_info,
                     }
-            elif isinstance(recv_obj, BatchTokenIDOut):
+            elif isinstance(recv_obj, BatchTokenIDOutput):
                 if self.server_args.stream_output and state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
                     output_token_ids = state.output_ids[state.last_output_offset :]
@@ -1696,10 +1381,10 @@ class TokenizerManager:
                         "output_ids": output_token_ids,
                         "meta_info": meta_info,
                     }
-            elif isinstance(recv_obj, BatchMultimodalOut):
+            elif isinstance(recv_obj, BatchMultimodalOutput):
                 raise NotImplementedError("BatchMultimodalOut not implemented")
             else:
-                assert isinstance(recv_obj, BatchEmbeddingOut)
+                assert isinstance(recv_obj, BatchEmbeddingOutput)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
@@ -1711,6 +1396,9 @@ class TokenizerManager:
                     meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+                trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+
                 del self.rid_to_state[rid]
 
                 # Mark ongoing LoRA request as finished.
@@ -1735,7 +1423,7 @@ class TokenizerManager:
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
-        recv_obj: BatchStrOut,
+        recv_obj: BatchStrOutput,
         recv_obj_index: int,
     ):
         if recv_obj.input_token_logprobs_val is None:
@@ -1853,13 +1541,19 @@ class TokenizerManager:
                 ret.append(None)
         return ret
 
-    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
         completion_tokens = (
             recv_obj.completion_tokens[i]
             if getattr(recv_obj, "completion_tokens", None)
             else 0
         )
 
+        custom_labels = getattr(state.obj, "custom_labels", None)
+        labels = (
+            {**self.metrics_collector.labels, **custom_labels}
+            if custom_labels
+            else self.metrics_collector.labels
+        )
         if (
             state.first_token_time == 0.0
             and self.disaggregation_mode != DisaggregationMode.PREFILL
@@ -1867,7 +1561,7 @@ class TokenizerManager:
             state.first_token_time = state.last_time = time.time()
             state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
-                state.first_token_time - state.created_time
+                labels, state.first_token_time - state.created_time
             )
         else:
             num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1875,6 +1569,7 @@ class TokenizerManager:
                 new_time = time.time()
                 interval = new_time - state.last_time
                 self.metrics_collector.observe_inter_token_latency(
+                    labels,
                     interval,
                     num_new_tokens,
                 )
@@ -1889,6 +1584,7 @@ class TokenizerManager:
                 or state.obj.sampling_params.get("structural_tag", None)
             )
             self.metrics_collector.observe_one_finished_request(
+                labels,
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
@@ -1941,7 +1637,7 @@ class TokenizerManager:
 
         asyncio.create_task(asyncio.to_thread(background_task))
 
-    def _handle_abort_req(self, recv_obj):
+    def _handle_abort_req(self, recv_obj: AbortReq):
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
@@ -1986,6 +1682,201 @@ class TokenizerManager:
         if len(self.model_update_tmp) == self.server_args.dp_size:
             self.model_update_result.set_result(self.model_update_tmp)
 
+    def _initialize_multi_item_delimiter_text(self):
+        """Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
+        if (
+            hasattr(self.server_args, "multi_item_scoring_delimiter")
+            and self.server_args.multi_item_scoring_delimiter is not None
+            and self.tokenizer is not None
+        ):
+            try:
+                self.multi_item_delimiter_text = self.tokenizer.decode(
+                    [self.server_args.multi_item_scoring_delimiter],
+                    skip_special_tokens=False,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
+                )
+                self.multi_item_delimiter_text = None
+
+    def _build_multi_item_token_sequence(
+        self, query: List[int], items: List[List[int]], delimiter_token_id: int
+    ) -> List[int]:
+        """
+        Build a single token sequence for multi-item scoring.
+        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Args:
+            query: Query token IDs
+            items: List of item token ID sequences
+            delimiter_token_id: Token ID to use as delimiter
+
+        Returns:
+            Combined token sequence
+        """
+        combined_sequence = query[:]  # Start with query
+
+        for item in items:
+            combined_sequence.append(delimiter_token_id)  # Add delimiter
+            combined_sequence.extend(item)  # Add item tokens
+
+        # Add final delimiter after the last item for logprob extraction
+        combined_sequence.append(delimiter_token_id)
+
+        return combined_sequence
+
+    def _extract_logprobs_for_tokens(
+        self, logprobs_data: List, label_token_ids: List[int]
+    ) -> Dict[int, float]:
+        """
+        Extract logprobs for specified token IDs from logprobs data.
+
+        Args:
+            logprobs_data: List of (logprob, token_id, text) tuples
+            label_token_ids: Token IDs to extract logprobs for
+
+        Returns:
+            Dictionary mapping token_id to logprob
+        """
+        logprobs = {}
+        if logprobs_data:
+            for logprob, token_id, _ in logprobs_data:
+                if token_id in label_token_ids:
+                    logprobs[token_id] = logprob
+        return logprobs
+
+    def _convert_logprobs_to_scores(
+        self,
+        logprobs: Dict[int, float],
+        label_token_ids: List[int],
+        apply_softmax: bool,
+    ) -> List[float]:
+        """
+        Convert logprobs dictionary to ordered score list.
+
+        Args:
+            logprobs: Dictionary mapping token_id to logprob
+            label_token_ids: Token IDs in desired order
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of scores in the same order as label_token_ids
+        """
+        score_list = [
+            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
+        ]
+
+        if apply_softmax:
+            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
+        else:
+            # Convert logprobs to probabilities if not using softmax
+            score_list = [
+                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
+            ]
+
+        return score_list
+
+    def _process_multi_item_scoring_results(
+        self,
+        results: Any,
+        items: List,
+        label_token_ids: List[int],
+        apply_softmax: bool,
+        batch_request=None,
+    ) -> List[List[float]]:
+        """
+        Process results from multi-item scoring request.
+        Extracts logprobs at delimiter positions from input_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            items: List of items being scored
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+            batch_request: The original batch request containing input sequence
+
+        Returns:
+            List of score lists, one for each item
+        """
+        single_result = results[0] if isinstance(results, list) else results
+
+        # For multi-item scoring, logprobs are in input_token_ids_logprobs
+        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
+
+        if not input_logprobs:
+            raise RuntimeError(
+                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
+            )
+
+        scores = []
+        num_items = len(items) if isinstance(items, list) else 1
+
+        # Check if we have the expected number of logprobs
+        expected_logprobs_count = num_items + 1
+        if len(input_logprobs) != expected_logprobs_count:
+            raise RuntimeError(
+                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
+                f"with {num_items} items, but got {len(input_logprobs)}. "
+                f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
+            )
+
+        # Skip the first delimiter (between query and first item) and process remaining delimiter positions
+        # We want to exclude the first one since it represents the boundary between query and first item, not an item boundary
+        start_idx = 1 if len(input_logprobs) > 1 else 0
+
+        # Process logprobs for each item position (excluding first delimiter)
+        for item_idx in range(num_items):
+            logprob_idx = start_idx + item_idx
+            item_logprobs_data = input_logprobs[logprob_idx]
+            logprobs = self._extract_logprobs_for_tokens(
+                item_logprobs_data, label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
+    def _process_single_item_scoring_results(
+        self, results: Any, label_token_ids: List[int], apply_softmax: bool
+    ) -> List[List[float]]:
+        """
+        Process results from single-item scoring request.
+        Single-item scoring results are stored in output_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of score lists, one for each result
+        """
+        scores = []
+
+        for result in results:
+            # For single-item scoring, logprobs are in output_token_ids_logprobs
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            if not output_logprobs or len(output_logprobs) == 0:
+                raise RuntimeError(
+                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
+                )
+
+            # Extract logprobs for the first (and only) position
+            logprobs = self._extract_logprobs_for_tokens(
+                output_logprobs[0], label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
     async def score_request(
         self,
         query: Optional[Union[str, List[int]]] = None,
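For orientation, here is a minimal standalone sketch of the token layout produced by `_build_multi_item_token_sequence` and of the delimiter bookkeeping that `_process_multi_item_scoring_results` relies on; the token IDs below are made up for illustration and are not taken from the package:

    from typing import List

    def build_multi_item_sequence(
        query: List[int], items: List[List[int]], delim: int
    ) -> List[int]:
        # Layout: query <delim> item1 <delim> item2 <delim> ... itemN <delim>
        seq = list(query)
        for item in items:
            seq.append(delim)
            seq.extend(item)
        seq.append(delim)
        return seq

    query = [101, 102]               # hypothetical query token IDs
    items = [[7, 8], [9], [10, 11]]  # hypothetical item token IDs
    DELIM = 0                        # hypothetical delimiter token ID

    seq = build_multi_item_sequence(query, items, DELIM)
    print(seq)  # [101, 102, 0, 7, 8, 0, 9, 0, 10, 11, 0]

    # One logprob entry is expected per delimiter occurrence: len(items) + 1 in total.
    # The first entry (query/first-item boundary) is skipped; entries 1..N score item1..itemN.
    assert seq.count(DELIM) == len(items) + 1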
@@ -1996,7 +1887,29 @@ class TokenizerManager:
         request: Optional[Any] = None,
     ) -> List[List[float]]:
         """
-        See Engine.score() for more details.
+        Score the probability of specified token IDs appearing after the given (query + item) pair.
+
+        This method supports two scoring approaches:
+        1. Single-Item scoring (default): Process each query+item pair independently
+        2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
+           multiple items into a single sequence using delimiter for efficient processing.
+           Note: item_first parameter is ignored in multi-item scoring mode since it uses
+           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Multi-item scoring works with both text and pre-tokenized inputs:
+        - Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
+        - Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+
+        Args:
+            query: The query text or pre-tokenized query token IDs
+            items: The item text(s) or pre-tokenized item token IDs
+            label_token_ids: List of token IDs to compute probabilities for
+            apply_softmax: Whether to normalize probabilities using softmax
+            item_first: If True, prepend items to query. Ignored for multi-item scoring.
+            request: Optional FastAPI request object
+
+        Returns:
+            List of lists containing probabilities for each item and each label token
         """
         if label_token_ids is None:
             raise ValueError("label_token_ids must be provided")
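A hedged usage sketch of the scoring API documented above; the wrapper function, prompt strings, and label token IDs are illustrative assumptions, not values from the package:

    async def score_demo(tokenizer_manager):
        # Single-item scoring: one prompt per item; scores come from output logprobs.
        scores = await tokenizer_manager.score_request(
            query="Is the following review positive?\n",
            items=["Great product!", "Terrible experience."],
            label_token_ids=[9891, 2201],  # hypothetical token IDs for "Yes" / "No"
            apply_softmax=True,
        )
        # scores -> [[p_yes, p_no], [p_yes, p_no]], one row per item.
        # If server_args.multi_item_scoring_delimiter is set, the same call instead sends
        # one combined prompt and reads scores at the delimiter positions.
        return scores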
@@ -2009,9 +1922,17 @@ class TokenizerManager:
                     f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                 )
 
+        # Check if multi-item scoring is enabled by presence of delimiter
+        use_multi_item_scoring = (
+            self.server_args.multi_item_scoring_delimiter is not None
+            and self.multi_item_delimiter_text is not None
+        )
+
         batch_request = GenerateReqInput(
             token_ids_logprob=label_token_ids,
             return_logprob=True,
+            # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
+            logprob_start_len=0 if use_multi_item_scoring else -1,
             stream=False,
             sampling_params={"max_new_tokens": 0},
         )
@@ -2023,12 +1944,23 @@ class TokenizerManager:
         ):
             # Both query and items are text
             items_list = [items] if isinstance(items, str) else items
-            if item_first:
-                prompts = [f"{item}{query}" for item in items_list]
-            else:
-                prompts = [f"{query}{item}" for item in items_list]
 
-            batch_request.text = prompts
+            if use_multi_item_scoring:
+                # Multi-item scoring: create single prompt with delimiter text
+                # Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+                # (item_first is ignored for multi-item scoring)
+                delimiter = self.multi_item_delimiter_text
+                combined_items = delimiter.join(items_list)
+                # Add final delimiter after the last item for logprob extraction
+                single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
+                batch_request.text = [single_prompt]
+            else:
+                # Single-item scoring: create separate prompts for each item
+                if item_first:
+                    prompts = [f"{item}{query}" for item in items_list]
+                else:
+                    prompts = [f"{query}{item}" for item in items_list]
+                batch_request.text = prompts
 
         elif (
             isinstance(query, list)
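A tiny illustration of the text path above, assuming the configured delimiter token decodes to the made-up text "<sep>":

    query = "Which caption fits the image?"
    items_list = ["a cat", "a dog", "a bird"]
    delimiter = "<sep>"  # hypothetical decoded delimiter text

    single_prompt = f"{query}{delimiter}{delimiter.join(items_list)}{delimiter}"
    print(single_prompt)
    # Which caption fits the image?<sep>a cat<sep>a dog<sep>a bird<sep>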
@@ -2037,57 +1969,75 @@ class TokenizerManager:
             and isinstance(items[0], list)
         ):
             # Both query and items are token IDs
-            if item_first:
-                input_ids_list = [item + query for item in items]
+            if use_multi_item_scoring:
+                # Multi-item scoring: concatenate with delimiter token ID
+                # Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+                delimiter_token_id = self.server_args.multi_item_scoring_delimiter
+                combined_input_ids = self._build_multi_item_token_sequence(
+                    query, items, delimiter_token_id
+                )
+                batch_request.input_ids = [combined_input_ids]
             else:
-                input_ids_list = [query + item for item in items]
-
-            batch_request.input_ids = input_ids_list
+                # Single-item scoring: process each item separately
+                if item_first:
+                    input_ids_list = [item + query for item in items]
+                else:
+                    input_ids_list = [query + item for item in items]
+                batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
             )
 
         results = await self.generate_request(batch_request, request).__anext__()
-        scores = []
-
-        for result in results:
-            # Get logprobs for each token
-            logprobs = {}
-
-            # For scoring requests, we read from output_token_ids_logprobs since we want
-            # the logprobs for specific tokens mentioned in the label_token_ids at
-            # the next position after the last token in the prompt
-            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
-
-            # Throw an error here if output_logprobs is None
-            if output_logprobs is None:
-                raise RuntimeError(
-                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
-                    "This usually indicates a problem with the scoring request or the backend output."
-                )
-
-            for logprob, token_id, _ in output_logprobs[0]:
-                if token_id in label_token_ids:
-                    logprobs[token_id] = logprob
 
-            # Get scores in order of label_token_ids
-            score_list = [
-                logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
-            ]
+        if use_multi_item_scoring:
+            # Multi-item scoring: extract scores from input_token_ids_logprobs
+            return self._process_multi_item_scoring_results(
+                results, items, label_token_ids, apply_softmax, batch_request
+            )
+        else:
+            # Single-item scoring: process each result separately
+            return self._process_single_item_scoring_results(
+                results, label_token_ids, apply_softmax
+            )
 
-            # Apply softmax to logprobs if needed
-            if apply_softmax:
-                score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
-            else:
-                # Convert logprobs to probabilities if not using softmax
-                score_list = [
-                    math.exp(x) if x != float("-inf") else 0.0 for x in score_list
-                ]
+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return
 
-            scores.append(score_list)
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_udpate_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_udpate_req)
 
-        return scores
+    def _trace_request_start(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        created_time: Optional[float] = None,
+    ):
+        if obj.is_single:
+            bootstrap_room = (
+                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
+            )
+            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
+        else:
+            for i in range(len(obj.rid)):
+                bootstrap_room = (
+                    obj.bootstrap_room[i]
+                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+                    else None
+                )
+                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_slice_start(
+                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                )
 
 
 class ServerStatus(Enum):
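Both scoring branches above return score lists produced the same way as `_convert_logprobs_to_scores` in this diff; a small worked example of that conversion with made-up logprob values:

    import math

    import torch

    label_token_ids = [11, 22, 33]
    logprobs = {11: -0.1, 33: -2.3}  # token 22 was not returned by the backend

    raw = [logprobs.get(t, float("-inf")) for t in label_token_ids]

    # apply_softmax=True: normalize over the label set only.
    softmaxed = torch.softmax(torch.tensor(raw), dim=0).tolist()

    # apply_softmax=False: independent probabilities; missing tokens become 0.0.
    independent = [math.exp(x) if x != float("-inf") else 0.0 for x in raw]

    print([round(s, 3) for s in softmaxed])    # [0.9, 0.0, 0.1]
    print([round(s, 3) for s in independent])  # [0.905, 0.0, 0.1]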
@@ -2134,57 +2084,12 @@ class SignalHandler:
 
     def running_phase_sigquit_handler(self, signum=None, frame=None):
         logger.error(
-            "Received sigquit from a child process. It usually means the child failed."
+            f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
         )
         self.tokenizer_manager.dump_requests_before_crash()
         kill_process_tree(os.getpid())
 
 
-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    enable_multi_tokenizer = False
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
 # Note: request abort handling logic
 # We should handle all of the following cases correctly.
 #