sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -31,18 +31,7 @@ from contextlib import nullcontext
 from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    Any,
-    Awaitable,
-    Deque,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 import fastapi
 import torch
@@ -53,80 +42,49 @@ from fastapi import BackgroundTasks

 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    KVClassType,
-    TransferBackend,
-    get_kv_class,
-)
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
-from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.lora.lora_registry import LoRARegistry
+from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
-    ClearHiCacheReqInput,
-    ClearHiCacheReqOutput,
-    CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
-    ExpertDistributionReq,
-    ExpertDistributionReqOutput,
-    FlushCacheReqInput,
-    FlushCacheReqOutput,
     FreezeGCReq,
     GenerateReqInput,
-    GetInternalStateReq,
-    GetInternalStateReqOutput,
-    GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
+    GetLoadReqInput,
     HealthCheckOutput,
-    InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
-    LoadLoRAAdapterReqInput,
-    LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
-    MultiTokenizerWarpper,
-    OpenSessionReqInput,
+    MultiTokenizerWrapper,
     OpenSessionReqOutput,
-    ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
-    ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
-    ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
-    SlowDownReqInput,
-    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UnloadLoRAAdapterReqInput,
-    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
-    UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
-    UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.tracing.trace import (
+    trace_get_proc_propagate_context,
+    trace_req_finish,
+    trace_req_start,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.utils import (
     configure_gc_warning,
     dataclass_to_string_truncated,
@@ -136,6 +94,11 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -180,7 +143,7 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


-class TokenizerManager:
+class TokenizerManager(TokenizerCommunicatorMixin):
     """TokenizerManager is a process that tokenizes the text."""

     def __init__(
@@ -199,6 +162,7 @@ class TokenizerManager:
             else None
         )
         self.crash_dump_folder = server_args.crash_dump_folder
+        self.enable_trace = server_args.enable_trace

         # Read model args
         self.model_path = server_args.model_path
@@ -210,8 +174,17 @@ class TokenizerManager:
         self.image_token_id = self.model_config.image_token_id
         self.max_req_input_len = None  # Will be set later in engine.py

+        speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.reserve_input_token_num = (
+            0
+            if speculative_algorithm.is_none()
+            else server_args.speculative_num_draft_tokens
+        )
+
         if self.model_config.is_multimodal:
-            import_processors()
+            import_processors("sglang.srt.multimodal.processors")
             try:
                 _processor = get_processor(
                     server_args.tokenizer_path,
@@ -262,6 +235,18 @@ class TokenizerManager:
                     trust_remote_code=server_args.trust_remote_code,
                     revision=server_args.revision,
                 )
+        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
+        if (
+            server_args.enable_dynamic_batch_tokenizer
+            and not server_args.skip_tokenizer_init
+        ):
+            self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
+                self.tokenizer,
+                max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
+                batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
+            )
+        else:
+            self.async_dynamic_batch_tokenizer = None

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
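
The `AsyncDynamicbatchTokenizer` wired in above (added as `sglang/srt/managers/async_dynamic_batch_tokenizer.py`, entry 220 in the file list) amortizes tokenizer overhead by coalescing concurrent single-text `encode()` calls into one batched tokenizer invocation. A minimal sketch of the coalescing idea under stated assumptions, not the actual implementation; the class name, queueing details, and result format here are illustrative, and per-call kwargs handling is omitted:

    import asyncio

    class DynamicBatchTokenizerSketch:
        """Coalesce concurrent encode() calls into one batched tokenizer call."""

        def __init__(self, tokenizer, max_batch_size=32, batch_wait_timeout_s=0.002):
            self.tokenizer = tokenizer          # HF-style callable tokenizer (assumed)
            self.max_batch_size = max_batch_size
            self.batch_wait_timeout_s = batch_wait_timeout_s
            self.queue = asyncio.Queue()
            self.worker = None

        async def encode(self, text):
            if self.worker is None:  # lazily start the batching loop
                self.worker = asyncio.create_task(self._batch_loop())
            fut = asyncio.get_running_loop().create_future()
            await self.queue.put((text, fut))
            return await fut

        async def _batch_loop(self):
            while True:
                batch = [await self.queue.get()]
                # Wait briefly so concurrent requests can share one tokenizer call.
                deadline = asyncio.get_running_loop().time() + self.batch_wait_timeout_s
                while len(batch) < self.max_batch_size:
                    remaining = deadline - asyncio.get_running_loop().time()
                    if remaining <= 0:
                        break
                    try:
                        batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                    except asyncio.TimeoutError:
                        break
                encoded = self.tokenizer([text for text, _ in batch])  # one batched call
                for i, (_, fut) in enumerate(batch):
                    fut.set_result({"input_ids": encoded["input_ids"][i]})
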
@@ -319,8 +304,10 @@ class TokenizerManager:
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

-        # For PD disaggregtion
-        self.init_disaggregation()
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.bootstrap_server = start_disagg_service(self.server_args)

         # For load balancing
         self.current_load = 0
@@ -328,11 +315,16 @@ class TokenizerManager:

         # Metrics
         if self.enable_metrics:
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                # TODO: Add lora name/path in the future,
+            }
+            if server_args.tokenizer_metrics_allowed_custom_labels:
+                for label in server_args.tokenizer_metrics_allowed_custom_labels:
+                    labels[label] = ""
             self.metrics_collector = TokenizerMetricsCollector(
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
+                server_args=server_args,
+                labels=labels,
                 bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                 bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                 bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
@@ -343,58 +335,14 @@ class TokenizerManager:
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)

-        # Communicators
-        self.init_weights_update_group_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_distributed_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_tensor_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_weights_by_name_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.release_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.resume_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.slow_down_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.flush_cache_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.clear_hicache_storage_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.profile_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.expert_distribution_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_lora_adapter_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
                     (
-                        BatchStrOut,
-                        BatchEmbeddingOut,
-                        BatchTokenIDOut,
-                        BatchMultimodalOut,
+                        BatchStrOutput,
+                        BatchEmbeddingOutput,
+                        BatchTokenIDOutput,
+                        BatchMultimodalOutput,
                     ),
                     self._handle_batch_output,
                 ),
@@ -404,100 +352,15 @@ class TokenizerManager:
                 (
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    InitWeightsUpdateGroupReqOutput,
-                    self.init_weights_update_group_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromDistributedReqOutput,
-                    self.update_weights_from_distributed_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromTensorReqOutput,
-                    self.update_weights_from_tensor_communicator.handle_recv,
-                ),
-                (
-                    GetWeightsByNameReqOutput,
-                    self.get_weights_by_name_communicator.handle_recv,
-                ),
-                (
-                    ReleaseMemoryOccupationReqOutput,
-                    self.release_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    ResumeMemoryOccupationReqOutput,
-                    self.resume_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    SlowDownReqOutput,
-                    self.slow_down_communicator.handle_recv,
-                ),
-                (
-                    ClearHiCacheReqOutput,
-                    self.clear_hicache_storage_communicator.handle_recv,
-                ),
-                (
-                    FlushCacheReqOutput,
-                    self.flush_cache_communicator.handle_recv,
-                ),
-                (
-                    ProfileReqOutput,
-                    self.profile_communicator.handle_recv,
-                ),
                 (
                     FreezeGCReq,
                     lambda x: None,
                 ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
-                (
-                    GetInternalStateReqOutput,
-                    self.get_internal_state_communicator.handle_recv,
-                ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
-                (
-                    ExpertDistributionReqOutput,
-                    self.expert_distribution_communicator.handle_recv,
-                ),
-                (
-                    LoRAUpdateResult,
-                    self.update_lora_adapter_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )

-    def init_disaggregation(self):
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_communicators(server_args)

     async def generate_request(
         self,
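
The fourteen hand-rolled `_Communicator` fields and their dispatcher entries deleted above now come from `TokenizerCommunicatorMixin.init_communicators` (see `sglang/srt/managers/tokenizer_communicator_mixin.py`, entry 239 in the file list). The underlying request/reply pattern is unchanged: send one object to the scheduler, then gather one reply per data-parallel rank before returning. A rough sketch of that pattern under stated assumptions; names are illustrative and details such as error propagation are omitted:

    import asyncio
    from typing import Generic, List, Optional, TypeVar

    T = TypeVar("T")

    class CommunicatorSketch(Generic[T]):
        """Fan out one request to the scheduler; gather one reply per DP rank."""

        def __init__(self, send_to_scheduler, fan_out: int):
            self.send_to_scheduler = send_to_scheduler
            self.fan_out = fan_out  # server_args.dp_size in the code above
            self.event: Optional[asyncio.Event] = None
            self.results: List[T] = []

        async def __call__(self, obj) -> List[T]:
            self.send_to_scheduler.send_pyobj(obj)
            self.results = []
            self.event = asyncio.Event()
            await self.event.wait()  # set by handle_recv() below
            return self.results

        def handle_recv(self, recv_obj: T):
            # Registered with the TypeBasedDispatcher for the matching output type.
            self.results.append(recv_obj)
            if len(self.results) == self.fan_out:
                self.event.set()
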
@@ -517,6 +380,9 @@ class TokenizerManager:
             # If it's a single value, add worker_id prefix
             obj.rid = f"{self.worker_id}_{obj.rid}"

+        if self.enable_trace:
+            self._trace_request_start(obj, created_time)
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -542,6 +408,144 @@ class TokenizerManager:
             ):
                 yield response

+    def _detect_input_format(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool
+    ) -> str:
+        """Detect the format of input texts for proper tokenization handling.
+
+        Returns:
+            - "single_string": Regular single text like "Hello world"
+            - "batch_strings": Regular batch like ["Hello", "World"]
+            - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
+        """
+        if isinstance(texts, str):
+            return "single_string"
+
+        if (
+            is_cross_encoder
+            and len(texts) > 0
+            and isinstance(texts[0], list)
+            and len(texts[0]) == 2
+        ):
+            return "cross_encoder_pairs"
+
+        return "batch_strings"
+
+    def _prepare_tokenizer_input(
+        self, texts: Union[str, List[str]], input_format: str
+    ) -> Union[List[str], List[List[str]]]:
+        """Prepare input for the tokenizer based on detected format."""
+        if input_format == "single_string":
+            return [texts]  # Wrap single string for batch processing
+        elif input_format == "cross_encoder_pairs":
+            return texts  # Already in correct format: [["query", "doc"]]
+        else:  # batch_strings
+            return texts  # Already in correct format: ["text1", "text2"]
+
+    def _extract_tokenizer_results(
+        self,
+        input_ids: List[List[int]],
+        token_type_ids: Optional[List[List[int]]],
+        input_format: str,
+        original_batch_size: int,
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """Extract results from tokenizer output based on input format."""
+
+        # For single inputs (string or single cross-encoder pair), extract first element
+        if (
+            input_format in ["single_string", "cross_encoder_pairs"]
+            and original_batch_size == 1
+        ):
+            single_input_ids = input_ids[0] if input_ids else []
+            single_token_type_ids = token_type_ids[0] if token_type_ids else None
+            return single_input_ids, single_token_type_ids
+
+        # For true batches, return as-is
+        return input_ids, token_type_ids
+
+    async def _tokenize_texts(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool = False
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """
+        Tokenize text(s) using the appropriate tokenizer strategy.
+
+        This method handles multiple input formats and chooses between async dynamic
+        batch tokenizer (for single texts only) and regular tokenizer.
+
+        Args:
+            texts: Text input in various formats:
+
+                Regular cases:
+                - Single string: "How are you?"
+                - Batch of strings: ["Hello", "World", "How are you?"]
+
+                Cross-encoder cases (sentence pairs for similarity/ranking):
+                - Single pair: [["query text", "document text"]]
+                - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
+
+            is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
+                Enables proper handling of sentence pairs with segment IDs.
+
+        Returns:
+            Single input cases:
+                Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
+                Example: ([101, 2129, 102], [0, 0, 0]) for single text
+                Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
+
+            Batch input cases:
+                Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
+                Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
+
+        Note: token_type_ids is None unless is_cross_encoder=True.
+        """
+        if not texts or self.tokenizer is None:
+            raise ValueError("texts cannot be empty and tokenizer must be initialized")
+
+        # Step 1: Detect input format and prepare for tokenization
+        input_format = self._detect_input_format(texts, is_cross_encoder)
+        tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
+        original_batch_size = len(texts) if not isinstance(texts, str) else 1
+
+        # Step 2: Set up tokenizer arguments
+        tokenizer_kwargs = (
+            {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
+        )
+
+        # Step 3: Choose tokenization strategy
+        use_async_tokenizer = (
+            self.async_dynamic_batch_tokenizer is not None
+            and input_format == "single_string"
+        )
+
+        if use_async_tokenizer:
+            logger.debug("Using async dynamic batch tokenizer for single text")
+            result = await self.async_dynamic_batch_tokenizer.encode(
+                tokenizer_input[0], **tokenizer_kwargs
+            )
+            # Convert to batch format for consistency
+            input_ids = [result["input_ids"]]
+            token_type_ids = (
+                [result["token_type_ids"]]
+                if is_cross_encoder and result.get("token_type_ids")
+                else None
+            )
+        else:
+            logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
+            encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
+            input_ids = encoded["input_ids"]
+            token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
+
+        # Step 4: Extract results based on input format
+        return self._extract_tokenizer_results(
+            input_ids, token_type_ids, input_format, original_batch_size
+        )
+
     async def _tokenize_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
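
Taken together, `_detect_input_format`, `_prepare_tokenizer_input`, `_extract_tokenizer_results`, and `_tokenize_texts` give a single entry point for every input shape. A few illustrative calls against a hypothetical `TokenizerManager` instance `tm` (return shapes follow the docstring above; the actual token IDs depend on the tokenizer):

    async def demo(tm):  # tm: a TokenizerManager instance (assumed)
        # Single string -> flat ids, no token_type_ids
        ids, tt = await tm._tokenize_texts("How are you?")

        # Batch of strings -> list of id lists, no token_type_ids
        batch_ids, tt = await tm._tokenize_texts(["Hello", "World"])

        # One cross-encoder pair -> flat ids plus segment ids like [0, 0, 1, 1]
        pair_ids, seg = await tm._tokenize_texts(
            [["query text", "document text"]], is_cross_encoder=True
        )
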
@@ -572,14 +576,10 @@ class TokenizerManager:
                     "accept text prompts. Please provide input_ids or re-initialize "
                     "the engine with skip_tokenizer_init=False."
                 )
-            encoded = self.tokenizer(
-                input_text, return_token_type_ids=is_cross_encoder_request
-            )

-            input_ids = encoded["input_ids"]
-            if is_cross_encoder_request:
-                input_ids = encoded["input_ids"][0]
-                token_type_ids = encoded.get("token_type_ids", [None])[0]
+            input_ids, token_type_ids = await self._tokenize_texts(
+                input_text, is_cross_encoder_request
+            )

         if self.mm_processor and obj.contains_mm_input():
             if not isinstance(obj.image_data, list):
@@ -599,6 +599,7 @@ class TokenizerManager:
             mm_inputs = None

         self._validate_one_request(obj, input_ids)
+        trace_slice_end("tokenize", obj.rid)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
@@ -611,6 +612,7 @@ class TokenizerManager:
         _max_req_len = self.context_len

         input_token_num = len(input_ids) if input_ids is not None else 0
+        input_token_num += self.reserve_input_token_num
         if input_token_num >= self.context_len:
             if self.server_args.allow_auto_truncate:
                 logger.warning(
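
The extra term added here reserves room for speculative draft tokens in the context-length check, so that tokens appended by the draft model at decode time still fit in the context window. A worked example with assumed numbers:

    context_len = 8192                 # model context window (assumed)
    reserve_input_token_num = 4        # speculative_num_draft_tokens (assumed)

    input_token_num = 8190             # prompt length that previously passed
    input_token_num += reserve_input_token_num    # now counted as 8194
    assert input_token_num >= context_len         # rejected, or truncated when
                                                  # --allow-auto-truncate is set
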
@@ -673,7 +675,7 @@ class TokenizerManager:
         ):
             raise ValueError(
                 "The server is not configured to enable custom logit processor. "
-                "Please set `--enable-custom-logits-processor` to enable this feature."
+                "Please set `--enable-custom-logit-processor` to enable this feature."
             )

     def _validate_input_ids_in_vocab(
@@ -712,7 +714,6 @@ class TokenizerManager:
            )

            tokenized_obj = TokenizedGenerateReqInput(
-                obj.rid,
                 input_text,
                 input_ids,
                 mm_inputs,
@@ -722,6 +723,7 @@ class TokenizerManager:
                 obj.top_logprobs_num,
                 obj.token_ids_logprob,
                 obj.stream,
+                rid=obj.rid,
                 bootstrap_host=obj.bootstrap_host,
                 bootstrap_port=obj.bootstrap_port,
                 bootstrap_room=obj.bootstrap_room,
@@ -731,15 +733,18 @@ class TokenizerManager:
                 custom_logit_processor=obj.custom_logit_processor,
                 return_hidden_states=obj.return_hidden_states,
                 data_parallel_rank=obj.data_parallel_rank,
+                priority=obj.priority,
+                extra_key=obj.extra_key,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                obj.rid,
                 input_text,
                 input_ids,
                 mm_inputs,
                 token_type_ids,
                 sampling_params,
+                rid=obj.rid,
+                priority=obj.priority,
             )

         return tokenized_obj
@@ -754,19 +759,30 @@ class TokenizerManager:
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]

-        # Batch tokenize all texts
-        encoded = self.tokenizer(texts)
-        input_ids_list = encoded["input_ids"]
+        # Check if any request is a cross-encoder request
+        is_cross_encoder_request = any(
+            isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
+            for req in requests
+        )
+
+        # Batch tokenize all texts using unified method
+        input_ids_list, token_type_ids_list = await self._tokenize_texts(
+            texts, is_cross_encoder_request
+        )

         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
             self._validate_one_request(obj[i], input_ids_list[i])
+            token_type_ids = (
+                token_type_ids_list[i] if token_type_ids_list is not None else None
+            )
             tokenized_objs.append(
                 self._create_tokenized_object(
-                    req, req.text, input_ids_list[i], None, None
+                    req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
+            trace_slice_end("tokenize", req.rid)
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs

 
@@ -794,9 +810,12 @@ class TokenizerManager:
794
810
  tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
795
811
  created_time: Optional[float] = None,
796
812
  ):
813
+ trace_slice_start("dispatch", obj.rid)
814
+ tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
797
815
  self.send_to_scheduler.send_pyobj(tokenized_obj)
798
816
  state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
799
817
  self.rid_to_state[obj.rid] = state
818
+ trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
800
819
  return state
801
820
 
802
821
  def _send_batch_request(
@@ -1014,73 +1033,16 @@ class TokenizerManager:
1014
1033
  except StopAsyncIteration:
1015
1034
  pass
1016
1035
 
1017
- async def flush_cache(self) -> FlushCacheReqOutput:
1018
- return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
1019
-
1020
- async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
1021
- """Clear the hierarchical cache storage."""
1022
- # Delegate to the scheduler to handle HiCacheStorage clearing
1023
- return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
1024
- 0
1025
- ]
1026
-
1027
1036
  def abort_request(self, rid: str = "", abort_all: bool = False):
1028
1037
  if not abort_all and rid not in self.rid_to_state:
1029
1038
  return
1030
- req = AbortReq(rid, abort_all)
1039
+ req = AbortReq(rid=rid, abort_all=abort_all)
1031
1040
  self.send_to_scheduler.send_pyobj(req)
1032
-
1033
1041
  if self.enable_metrics:
1034
- self.metrics_collector.observe_one_aborted_request()
1035
-
1036
- async def start_profile(
1037
- self,
1038
- output_dir: Optional[str] = None,
1039
- start_step: Optional[int] = None,
1040
- num_steps: Optional[int] = None,
1041
- activities: Optional[List[str]] = None,
1042
- with_stack: Optional[bool] = None,
1043
- record_shapes: Optional[bool] = None,
1044
- profile_by_stage: bool = False,
1045
- ):
1046
- self.auto_create_handle_loop()
1047
- env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
1048
- with_stack = False if with_stack is False or env_with_stack is False else True
1049
- req = ProfileReq(
1050
- type=ProfileReqType.START_PROFILE,
1051
- output_dir=output_dir,
1052
- start_step=start_step,
1053
- num_steps=num_steps,
1054
- activities=activities,
1055
- with_stack=with_stack,
1056
- record_shapes=record_shapes,
1057
- profile_by_stage=profile_by_stage,
1058
- profile_id=str(time.time()),
1059
- )
1060
- return await self._execute_profile(req)
1061
-
1062
- async def stop_profile(self):
1063
- self.auto_create_handle_loop()
1064
- req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
1065
- return await self._execute_profile(req)
1066
-
1067
- async def _execute_profile(self, req: ProfileReq):
1068
- result = (await self.profile_communicator(req))[0]
1069
- if not result.success:
1070
- raise RuntimeError(result.message)
1071
- return result
1072
-
1073
- async def start_expert_distribution_record(self):
1074
- self.auto_create_handle_loop()
1075
- await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
1076
-
1077
- async def stop_expert_distribution_record(self):
1078
- self.auto_create_handle_loop()
1079
- await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
1080
-
1081
- async def dump_expert_distribution_record(self):
1082
- self.auto_create_handle_loop()
1083
- await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
1042
+ # TODO: also use custom_labels from the request
1043
+ self.metrics_collector.observe_one_aborted_request(
1044
+ self.metrics_collector.labels
1045
+ )
1084
1046
 
1085
1047
  async def pause_generation(self):
1086
1048
  async with self.is_pause_cond:
@@ -1117,7 +1079,7 @@ class TokenizerManager:
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
         if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1142,291 +1104,6 @@ class TokenizerManager:
         all_paused_requests = [r.num_paused_requests for r in result]
         return all_success, all_message, all_paused_requests

-    async def init_weights_update_group(
-        self,
-        obj: InitWeightsUpdateGroupReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
-
-    async def update_weights_from_distributed(
-        self,
-        obj: UpdateWeightsFromDistributedReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
-
-    async def update_weights_from_tensor(
-        self,
-        obj: UpdateWeightsFromTensorReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_tensor_communicator(obj))[0]
-            return result.success, result.message
-
-    async def load_lora_adapter(
-        self,
-        obj: LoadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> LoadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start load Lora adapter. Lora name=%s, path=%s",
-                obj.lora_name,
-                obj.lora_path,
-            )
-
-            async with self.lora_update_lock:
-                if (
-                    self.server_args.max_loaded_loras is not None
-                    and self.lora_registry.num_registered_loras
-                    >= self.server_args.max_loaded_loras
-                ):
-                    raise ValueError(
-                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
-                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
-                        "Please unload some LoRA adapters before loading new ones."
-                    )
-
-                # Generate new uniquely identifiable LoRARef object.
-                new_adapter = LoRARef(
-                    lora_name=obj.lora_name,
-                    lora_path=obj.lora_path,
-                    pinned=obj.pinned,
-                )
-
-                # Trigger the actual loading operation at the backend processes.
-                obj.lora_id = new_adapter.lora_id
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                # Register the LoRA adapter only after loading is successful.
-                if result.success:
-                    await self.lora_registry.register(new_adapter)
-
-                return result
-        except ValueError as e:
-            return LoadLoRAAdapterReqOutput(
-                success=False,
-                error_message=str(e),
-            )
-
-    async def unload_lora_adapter(
-        self,
-        obj: UnloadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> UnloadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            assert (
-                obj.lora_name is not None
-            ), "lora_name must be provided to unload LoRA adapter"
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start unload Lora adapter. Lora name=%s",
-                obj.lora_name,
-            )
-
-            async with self.lora_update_lock:
-                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-                # from being started.
-                lora_id = await self.lora_registry.unregister(obj.lora_name)
-                obj.lora_id = lora_id
-
-                # Initiate the actual unloading operation at the backend processes only after all
-                # ongoing requests using this LoRA adapter are finished.
-                await self.lora_registry.wait_for_unload(lora_id)
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                return result
-        except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
-    async def open_session(
-        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-
-        if obj.session_id is None:
-            obj.session_id = uuid.uuid4().hex
-        elif obj.session_id in self.session_futures:
-            return None
-
-        if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
-        self.send_to_scheduler.send_pyobj(obj)
-
-        self.session_futures[obj.session_id] = asyncio.Future()
-        session_id = await self.session_futures[obj.session_id]
-        del self.session_futures[obj.session_id]
-        return session_id
-
-    async def close_session(
-        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
-    ):
-        await self.send_to_scheduler.send_pyobj(obj)
-
-    async def get_internal_state(self) -> List[Dict[Any, Any]]:
-        req = GetInternalStateReq()
-        responses: List[GetInternalStateReqOutput] = (
-            await self.get_internal_state_communicator(req)
-        )
-        # Many DP ranks
-        return [res.internal_state for res in responses]
-
-    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
-        responses: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return [res.updated for res in responses]
-
-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     def get_log_request_metadata(self):
         max_length = None
1378
- skip_names = None
1379
- out_skip_names = None
1380
- if self.log_requests:
1381
- if self.log_requests_level == 0:
1382
- max_length = 1 << 30
1383
- skip_names = set(
1384
- [
1385
- "text",
1386
- "input_ids",
1387
- "input_embeds",
1388
- "image_data",
1389
- "audio_data",
1390
- "lora_path",
1391
- "sampling_params",
1392
- ]
1393
- )
1394
- out_skip_names = set(
1395
- [
1396
- "text",
1397
- "output_ids",
1398
- "embedding",
1399
- ]
1400
- )
1401
- elif self.log_requests_level == 1:
1402
- max_length = 1 << 30
1403
- skip_names = set(
1404
- [
1405
- "text",
1406
- "input_ids",
1407
- "input_embeds",
1408
- "image_data",
1409
- "audio_data",
1410
- "lora_path",
1411
- ]
1412
- )
1413
- out_skip_names = set(
1414
- [
1415
- "text",
1416
- "output_ids",
1417
- "embedding",
1418
- ]
1419
- )
1420
- elif self.log_requests_level == 2:
1421
- max_length = 2048
1422
- elif self.log_requests_level == 3:
1423
- max_length = 1 << 30
1424
- else:
1425
- raise ValueError(
1426
- f"Invalid --log-requests-level: {self.log_requests_level=}"
1427
- )
1428
- return max_length, skip_names, out_skip_names
1429
-
1430
1107
  def configure_logging(self, obj: ConfigureLoggingReq):
1431
1108
  if obj.log_requests is not None:
1432
1109
  self.log_requests = obj.log_requests
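
Editorial note: the removed LoRA handlers above (relocated elsewhere in this release) encode a deliberate ordering: load weights first and register the adapter only on success; on the way out, unregister first, drain in-flight requests, then unload. A minimal sketch of that discipline, assuming hypothetical `registry`, `backend`, and `lock` objects rather than sglang's actual `LoRARegistry` API:

```python
import asyncio

async def load_adapter(lock: asyncio.Lock, registry, backend, name: str, path: str):
    # Load first, register second: a half-loaded adapter must never be routable.
    async with lock:
        result = await backend.load(name, path)
        if result.success:
            await registry.register(name)
        return result

async def unload_adapter(lock: asyncio.Lock, registry, backend, name: str):
    # Unregister first so no new request can pick the adapter, then wait for
    # in-flight requests to drain before freeing the weights.
    async with lock:
        lora_id = await registry.unregister(name)
        await registry.wait_for_unload(lora_id)
        return await backend.unload(lora_id)
```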
@@ -1491,6 +1168,9 @@ class TokenizerManager:
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )
 
     def dump_requests_before_crash(self):
         if self.crash_dump_performed:
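
The new hunk registers `watch_load_thread` alongside the existing watchdog tasks; keeping each task in `self.asyncio_tasks` holds a strong reference so the event loop cannot garbage-collect it mid-flight. The exact behavior of `print_exception_wrapper` is not shown in this diff; a sketch of what such a wrapper typically does (the logging details are an assumption, not sglang's implementation):

```python
import logging

logger = logging.getLogger(__name__)

async def print_exception_wrapper(func):
    # Run a long-lived coroutine function; if it ever raises, log the full
    # traceback instead of letting the exception vanish with the Task object.
    try:
        await func()
    except Exception:
        logger.exception("Background task %s crashed", getattr(func, "__name__", func))
        raise
```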
@@ -1582,12 +1262,12 @@ class TokenizerManager:
         # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
+            remaining_rids = list(self.rid_to_state.keys())
 
             if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
-                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
-                    remain_num_req,
+                    "Signal SIGTERM received while health check failed. Force exiting."
                 )
                 self.dump_requests_before_crash()
                 break
@@ -1595,13 +1275,12 @@ class TokenizerManager:
             elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                 # if force shutdown flag set, exit immediately
                 logger.error(
-                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
-                    remain_num_req,
+                    "Signal SIGTERM received while force shutdown flag set. Force exiting."
                 )
                 break
 
             logger.info(
-                f"Gracefully exiting... remaining number of requests {remain_num_req}"
+                f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
             )
             if remain_num_req > 0:
                 await asyncio.sleep(5)
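
These two hunks rework the SIGTERM drain loop: the force-exit paths no longer report a request count, while the graceful path now logs the remaining request ids as well. The control flow reduces to a poll-and-sleep loop; a condensed sketch, with `pending`, `unhealthy`, and `force_shutdown` as illustrative stand-ins for the manager's real state:

```python
import asyncio

async def drain_requests(pending: dict, unhealthy, force_shutdown, poll_s: float = 5.0):
    # Poll until the request table empties, except on the two force-exit paths,
    # which break out immediately without waiting for in-flight requests.
    while True:
        if unhealthy() or force_shutdown():
            break  # dump state and exit right away
        remaining = list(pending.keys())
        print(f"Gracefully exiting... {len(remaining)} remaining: {remaining}")
        if not remaining:
            break
        await asyncio.sleep(poll_s)
```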
@@ -1622,7 +1301,10 @@ class TokenizerManager:
     def _handle_batch_output(
         self,
         recv_obj: Union[
-            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
+            BatchStrOutput,
+            BatchEmbeddingOutput,
+            BatchMultimodalOutput,
+            BatchTokenIDOutput,
         ],
     ):
         for i, rid in enumerate(recv_obj.rids):
@@ -1656,7 +1338,7 @@ class TokenizerManager:
                     i,
                 )
 
-            if not isinstance(recv_obj, BatchEmbeddingOut):
+            if not isinstance(recv_obj, BatchEmbeddingOutput):
                 meta_info.update(
                     {
                         "completion_tokens": recv_obj.completion_tokens[i],
@@ -1667,7 +1349,7 @@ class TokenizerManager:
             if getattr(recv_obj, "output_hidden_states", None):
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
 
-            if isinstance(recv_obj, BatchStrOut):
+            if isinstance(recv_obj, BatchStrOutput):
                 state.text += recv_obj.output_strs[i]
                 if state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
@@ -1682,7 +1364,7 @@ class TokenizerManager:
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchTokenIDOut):
+            elif isinstance(recv_obj, BatchTokenIDOutput):
                 if self.server_args.stream_output and state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
                     output_token_ids = state.output_ids[state.last_output_offset :]
@@ -1695,10 +1377,10 @@ class TokenizerManager:
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchMultimodalOut):
+            elif isinstance(recv_obj, BatchMultimodalOutput):
                 raise NotImplementedError("BatchMultimodalOut not implemented")
             else:
-                assert isinstance(recv_obj, BatchEmbeddingOut)
+                assert isinstance(recv_obj, BatchEmbeddingOutput)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
@@ -1710,6 +1392,9 @@ class TokenizerManager:
                     meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+                trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+
                 del self.rid_to_state[rid]
 
                 # Mark ongoing LoRA request as finished.
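
`trace_req_finish` is stamped with `int(state.finished_time * 1e9)`: the tracing layer evidently takes integer nanoseconds since the epoch, while `time.time()` returns float seconds. A quick illustration of the conversion (the equivalence with `time.time_ns()` is a general Python fact, not sglang-specific):

```python
import time

finished_time = time.time()     # float seconds since the epoch
ts = int(finished_time * 1e9)   # integer nanoseconds, as passed to trace_req_finish
# time.time_ns() yields the same scale directly, without the float round-trip.
assert abs(time.time_ns() - ts) < 1_000_000_000  # both sampled within one second
```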
@@ -1734,7 +1419,7 @@ class TokenizerManager:
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
-        recv_obj: BatchStrOut,
+        recv_obj: BatchStrOutput,
         recv_obj_index: int,
     ):
         if recv_obj.input_token_logprobs_val is None:
@@ -1852,13 +1537,19 @@ class TokenizerManager:
             ret.append(None)
         return ret
 
-    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
         completion_tokens = (
             recv_obj.completion_tokens[i]
             if getattr(recv_obj, "completion_tokens", None)
             else 0
         )
 
+        custom_labels = getattr(state.obj, "custom_labels", None)
+        labels = (
+            {**self.metrics_collector.labels, **custom_labels}
+            if custom_labels
+            else self.metrics_collector.labels
+        )
         if (
             state.first_token_time == 0.0
             and self.disaggregation_mode != DisaggregationMode.PREFILL
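
The metrics hunks thread a `labels` dict through every observation. The merge is ordinary dict unpacking, where right-hand keys win, so per-request `custom_labels` override the collector's defaults; when there are none, the shared base dict is reused unchanged. A standalone illustration with made-up label values:

```python
base_labels = {"model_name": "llama", "instance": "pod-0"}
custom_labels = {"tenant": "team-a", "instance": "pod-0-canary"}

# Right-hand side wins on key collisions, so per-request labels override
# the collector-wide defaults; a falsy custom_labels falls back to the base dict.
labels = {**base_labels, **custom_labels} if custom_labels else base_labels
assert labels == {"model_name": "llama", "instance": "pod-0-canary", "tenant": "team-a"}
```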
@@ -1866,7 +1557,7 @@ class TokenizerManager:
             state.first_token_time = state.last_time = time.time()
             state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
-                state.first_token_time - state.created_time
+                labels, state.first_token_time - state.created_time
             )
         else:
             num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1874,6 +1565,7 @@ class TokenizerManager:
                 new_time = time.time()
                 interval = new_time - state.last_time
                 self.metrics_collector.observe_inter_token_latency(
+                    labels,
                     interval,
                     num_new_tokens,
                 )
@@ -1888,6 +1580,7 @@ class TokenizerManager:
                 or state.obj.sampling_params.get("structural_tag", None)
             )
             self.metrics_collector.observe_one_finished_request(
+                labels,
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
@@ -1940,7 +1633,7 @@ class TokenizerManager:
 
         asyncio.create_task(asyncio.to_thread(background_task))
 
-    def _handle_abort_req(self, recv_obj):
+    def _handle_abort_req(self, recv_obj: AbortReq):
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
@@ -2059,11 +1752,15 @@ class TokenizerManager:
         # the next position after the last token in the prompt
         output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
 
-        # Throw an error here if output_logprobs is None
-        if output_logprobs is None:
+        # Check if output_logprobs is properly populated
+        if (
+            output_logprobs is None
+            or not output_logprobs
+            or len(output_logprobs) == 0
+        ):
             raise RuntimeError(
-                f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
-                "This usually indicates a problem with the scoring request or the backend output."
+                f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for the scoring request."
             )
 
         for logprob, token_id, _ in output_logprobs[0]:
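
A note on the new guard: `not output_logprobs` is already true for both `None` and an empty list, so the `is None` and `len(...) == 0` arms are redundant, though harmless thanks to short-circuit evaluation. The check below demonstrates the equivalence on the three shapes this value can take:

```python
for value in (None, [], [[(-0.25, 42, None)]]):
    # Short-circuiting keeps len(None) from ever being evaluated.
    assert (value is None or not value or len(value) == 0) == (not value)
```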
@@ -2088,6 +1785,43 @@ class TokenizerManager:
 
         return scores
 
+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return
+
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_update_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_update_req)
+
+    def _trace_request_start(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        created_time: Optional[float] = None,
+    ):
+        if obj.is_single:
+            bootstrap_room = (
+                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
+            )
+            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
+        else:
+            for i in range(len(obj.rid)):
+                bootstrap_room = (
+                    obj.bootstrap_room[i]
+                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+                    else None
+                )
+                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_slice_start(
+                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                )
+
 
 class ServerStatus(Enum):
     Up = "Up"
@@ -2133,57 +1867,12 @@ class SignalHandler:
 
     def running_phase_sigquit_handler(self, signum=None, frame=None):
         logger.error(
-            "Received sigquit from a child process. It usually means the child failed."
+            f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
         )
         self.tokenizer_manager.dump_requests_before_crash()
         kill_process_tree(os.getpid())
 
 
-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    enable_multi_tokenizer = False
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
 # Note: request abort handling logic
 # We should handle all of the following cases correctly.
 #