sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
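
Diff shown below: sglang/srt/managers/tokenizer_manager.py (file 232 in the list above, +320 -632).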
@@ -31,18 +31,7 @@ from contextlib import nullcontext
 from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    Any,
-    Awaitable,
-    Deque,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 import fastapi
 import torch
@@ -53,80 +42,49 @@ from fastapi import BackgroundTasks
 
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    KVClassType,
-    TransferBackend,
-    get_kv_class,
-)
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
-from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.lora.lora_registry import LoRARegistry
+from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchMultimodalOut,
-    BatchStrOut,
-    BatchTokenIDOut,
+    BatchEmbeddingOutput,
+    BatchMultimodalOutput,
+    BatchStrOutput,
+    BatchTokenIDOutput,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
-    ClearHiCacheReqInput,
-    ClearHiCacheReqOutput,
-    CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
-    ExpertDistributionReq,
-    ExpertDistributionReqOutput,
-    FlushCacheReqInput,
-    FlushCacheReqOutput,
     FreezeGCReq,
     GenerateReqInput,
-    GetInternalStateReq,
-    GetInternalStateReqOutput,
-    GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
+    GetLoadReqInput,
     HealthCheckOutput,
-    InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
-    LoadLoRAAdapterReqInput,
-    LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
-    MultiTokenizerWarpper,
-    OpenSessionReqInput,
+    MultiTokenizerWrapper,
     OpenSessionReqOutput,
-    ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
-    ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
-    ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
-    SlowDownReqInput,
-    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UnloadLoRAAdapterReqInput,
-    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
-    UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
-    UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.tracing.trace import (
+    trace_get_proc_propagate_context,
+    trace_req_finish,
+    trace_req_start,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.utils import (
     configure_gc_warning,
     dataclass_to_string_truncated,
@@ -136,6 +94,11 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -180,7 +143,7 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
-class TokenizerManager:
+class TokenizerManager(TokenizerCommunicatorMixin):
     """TokenizerManager is a process that tokenizes the text."""
 
     def __init__(
@@ -199,6 +162,7 @@ class TokenizerManager:
             else None
         )
         self.crash_dump_folder = server_args.crash_dump_folder
+        self.enable_trace = server_args.enable_trace
 
         # Read model args
         self.model_path = server_args.model_path
@@ -210,8 +174,17 @@ class TokenizerManager:
         self.image_token_id = self.model_config.image_token_id
         self.max_req_input_len = None  # Will be set later in engine.py
 
+        speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.reserve_input_token_num = (
+            0
+            if speculative_algorithm.is_none()
+            else server_args.speculative_num_draft_tokens
+        )
+
         if self.model_config.is_multimodal:
-            import_processors()
+            import_processors("sglang.srt.multimodal.processors")
             try:
                 _processor = get_processor(
                     server_args.tokenizer_path,
@@ -262,6 +235,18 @@ class TokenizerManager:
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
+        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
+        if (
+            server_args.enable_dynamic_batch_tokenizer
+            and not server_args.skip_tokenizer_init
+        ):
+            self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
+                self.tokenizer,
+                max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
+                batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
+            )
+        else:
+            self.async_dynamic_batch_tokenizer = None
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -319,8 +304,10 @@
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()
 
-        # For PD disaggregtion
-        self.init_disaggregation()
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.bootstrap_server = start_disagg_service(self.server_args)
 
         # For load balancing
         self.current_load = 0
@@ -328,12 +315,16 @@
 
         # Metrics
         if self.enable_metrics:
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                # TODO: Add lora name/path in the future,
+            }
+            if server_args.tokenizer_metrics_allowed_custom_labels:
+                for label in server_args.tokenizer_metrics_allowed_custom_labels:
+                    labels[label] = ""
             self.metrics_collector = TokenizerMetricsCollector(
                 server_args=server_args,
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
+                labels=labels,
                 bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                 bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                 bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
@@ -344,58 +335,14 @@
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)
 
-        # Communicators
-        self.init_weights_update_group_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_distributed_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_tensor_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_weights_by_name_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.release_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.resume_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.slow_down_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.flush_cache_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.clear_hicache_storage_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.profile_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.expert_distribution_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_lora_adapter_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
                     (
-                        BatchStrOut,
-                        BatchEmbeddingOut,
-                        BatchTokenIDOut,
-                        BatchMultimodalOut,
+                        BatchStrOutput,
+                        BatchEmbeddingOutput,
+                        BatchTokenIDOutput,
+                        BatchMultimodalOutput,
                     ),
                     self._handle_batch_output,
                 ),
@@ -405,100 +352,15 @@
                 (
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    InitWeightsUpdateGroupReqOutput,
-                    self.init_weights_update_group_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromDistributedReqOutput,
-                    self.update_weights_from_distributed_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromTensorReqOutput,
-                    self.update_weights_from_tensor_communicator.handle_recv,
-                ),
-                (
-                    GetWeightsByNameReqOutput,
-                    self.get_weights_by_name_communicator.handle_recv,
-                ),
-                (
-                    ReleaseMemoryOccupationReqOutput,
-                    self.release_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    ResumeMemoryOccupationReqOutput,
-                    self.resume_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    SlowDownReqOutput,
-                    self.slow_down_communicator.handle_recv,
-                ),
-                (
-                    ClearHiCacheReqOutput,
-                    self.clear_hicache_storage_communicator.handle_recv,
-                ),
-                (
-                    FlushCacheReqOutput,
-                    self.flush_cache_communicator.handle_recv,
-                ),
-                (
-                    ProfileReqOutput,
-                    self.profile_communicator.handle_recv,
-                ),
                 (
                     FreezeGCReq,
                     lambda x: None,
                 ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
-                (
-                    GetInternalStateReqOutput,
-                    self.get_internal_state_communicator.handle_recv,
-                ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
-                (
-                    ExpertDistributionReqOutput,
-                    self.expert_distribution_communicator.handle_recv,
-                ),
-                (
-                    LoRAUpdateResult,
-                    self.update_lora_adapter_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
 
-    def init_disaggregation(self):
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_communicators(server_args)
 
     async def generate_request(
         self,
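
For orientation: the _result_dispatcher assembled in the hunks above routes each scheduler result to a handler by Python type, and the per-endpoint _Communicator plumbing that used to live here has moved into TokenizerCommunicatorMixin (wired up by the new init_communicators(server_args) call). Below is a minimal, self-contained sketch of that dispatch pattern, with local names rather than sglang's actual implementation:

from typing import Any, Callable, List, Tuple, Type, Union

class TypeDispatcherSketch:
    """Route an object to the first handler whose registered type(s) match."""

    def __init__(
        self,
        mapping: List[Tuple[Union[Type, Tuple[Type, ...]], Callable[[Any], Any]]],
    ):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        for types, handler in self._mapping:
            # isinstance accepts a single type or a tuple of types, which is
            # why the four batch output types above can share one handler entry.
            if isinstance(obj, types):
                return handler(obj)
        raise ValueError(f"no handler registered for {type(obj)}")

dispatch = TypeDispatcherSketch([((int, float), lambda x: x * 2), (str, len)])
assert dispatch(3) == 6 and dispatch("abc") == 3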
@@ -518,6 +380,9 @@
             # If it's a single value, add worker_id prefix
             obj.rid = f"{self.worker_id}_{obj.rid}"
 
+        if self.enable_trace:
+            self._trace_request_start(obj, created_time)
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -543,6 +408,144 @@
         ):
             yield response
 
+    def _detect_input_format(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool
+    ) -> str:
+        """Detect the format of input texts for proper tokenization handling.
+
+        Returns:
+            - "single_string": Regular single text like "Hello world"
+            - "batch_strings": Regular batch like ["Hello", "World"]
+            - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
+        """
+        if isinstance(texts, str):
+            return "single_string"
+
+        if (
+            is_cross_encoder
+            and len(texts) > 0
+            and isinstance(texts[0], list)
+            and len(texts[0]) == 2
+        ):
+            return "cross_encoder_pairs"
+
+        return "batch_strings"
+
+    def _prepare_tokenizer_input(
+        self, texts: Union[str, List[str]], input_format: str
+    ) -> Union[List[str], List[List[str]]]:
+        """Prepare input for the tokenizer based on detected format."""
+        if input_format == "single_string":
+            return [texts]  # Wrap single string for batch processing
+        elif input_format == "cross_encoder_pairs":
+            return texts  # Already in correct format: [["query", "doc"]]
+        else:  # batch_strings
+            return texts  # Already in correct format: ["text1", "text2"]
+
+    def _extract_tokenizer_results(
+        self,
+        input_ids: List[List[int]],
+        token_type_ids: Optional[List[List[int]]],
+        input_format: str,
+        original_batch_size: int,
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """Extract results from tokenizer output based on input format."""
+
+        # For single inputs (string or single cross-encoder pair), extract first element
+        if (
+            input_format in ["single_string", "cross_encoder_pairs"]
+            and original_batch_size == 1
+        ):
+            single_input_ids = input_ids[0] if input_ids else []
+            single_token_type_ids = token_type_ids[0] if token_type_ids else None
+            return single_input_ids, single_token_type_ids
+
+        # For true batches, return as-is
+        return input_ids, token_type_ids
+
+    async def _tokenize_texts(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool = False
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """
+        Tokenize text(s) using the appropriate tokenizer strategy.
+
+        This method handles multiple input formats and chooses between async dynamic
+        batch tokenizer (for single texts only) and regular tokenizer.
+
+        Args:
+            texts: Text input in various formats:
+
+                Regular cases:
+                - Single string: "How are you?"
+                - Batch of strings: ["Hello", "World", "How are you?"]
+
+                Cross-encoder cases (sentence pairs for similarity/ranking):
+                - Single pair: [["query text", "document text"]]
+                - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
+
+            is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
+                Enables proper handling of sentence pairs with segment IDs.
+
+        Returns:
+            Single input cases:
+                Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
+                Example: ([101, 2129, 102], [0, 0, 0]) for single text
+                Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
+
+            Batch input cases:
+                Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
+                Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
+
+        Note: token_type_ids is None unless is_cross_encoder=True.
+        """
+        if not texts or self.tokenizer is None:
+            raise ValueError("texts cannot be empty and tokenizer must be initialized")
+
+        # Step 1: Detect input format and prepare for tokenization
+        input_format = self._detect_input_format(texts, is_cross_encoder)
+        tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
+        original_batch_size = len(texts) if not isinstance(texts, str) else 1
+
+        # Step 2: Set up tokenizer arguments
+        tokenizer_kwargs = (
+            {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
+        )
+
+        # Step 3: Choose tokenization strategy
+        use_async_tokenizer = (
+            self.async_dynamic_batch_tokenizer is not None
+            and input_format == "single_string"
+        )
+
+        if use_async_tokenizer:
+            logger.debug("Using async dynamic batch tokenizer for single text")
+            result = await self.async_dynamic_batch_tokenizer.encode(
+                tokenizer_input[0], **tokenizer_kwargs
+            )
+            # Convert to batch format for consistency
+            input_ids = [result["input_ids"]]
+            token_type_ids = (
+                [result["token_type_ids"]]
+                if is_cross_encoder and result.get("token_type_ids")
+                else None
+            )
+        else:
+            logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
+            encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
+            input_ids = encoded["input_ids"]
+            token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
+
+        # Step 4: Extract results based on input format
+        return self._extract_tokenizer_results(
+            input_ids, token_type_ids, input_format, original_batch_size
+        )
+
     async def _tokenize_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
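
To make the shape contract of these new helpers concrete, here is a runnable sketch with local names and made-up token ids (not sglang APIs): a single input is wrapped into a batch of one before tokenization and unwrapped afterwards, so callers get flat ids for single inputs and nested ids for true batches.

from typing import List, Union

def detect(texts: Union[str, List], is_cross_encoder: bool) -> str:
    # Mirrors the _detect_input_format logic above.
    if isinstance(texts, str):
        return "single_string"
    if is_cross_encoder and texts and isinstance(texts[0], list) and len(texts[0]) == 2:
        return "cross_encoder_pairs"
    return "batch_strings"

assert detect("How are you?", False) == "single_string"
assert detect(["Hello", "World"], False) == "batch_strings"
assert detect([["query", "document"]], True) == "cross_encoder_pairs"

# Wrap-then-unwrap round trip for a single string:
texts = "How are you?"
batch = [texts] if detect(texts, False) == "single_string" else texts
ids = [[101, 2129, 102]]         # stand-in for tokenizer output on the batch of one
flat = ids[0] if len(batch) == 1 else ids
assert flat == [101, 2129, 102]  # the caller sees a flat list, as documented above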
@@ -573,14 +576,10 @@
                 "accept text prompts. Please provide input_ids or re-initialize "
                 "the engine with skip_tokenizer_init=False."
             )
-            encoded = self.tokenizer(
-                input_text, return_token_type_ids=is_cross_encoder_request
-            )
 
-            input_ids = encoded["input_ids"]
-            if is_cross_encoder_request:
-                input_ids = encoded["input_ids"][0]
-                token_type_ids = encoded.get("token_type_ids", [None])[0]
+            input_ids, token_type_ids = await self._tokenize_texts(
+                input_text, is_cross_encoder_request
+            )
 
         if self.mm_processor and obj.contains_mm_input():
             if not isinstance(obj.image_data, list):
@@ -600,6 +599,7 @@
             mm_inputs = None
 
         self._validate_one_request(obj, input_ids)
+        trace_slice_end("tokenize", obj.rid)
        return self._create_tokenized_object(
            obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
        )
@@ -612,6 +612,7 @@
             _max_req_len = self.context_len
 
         input_token_num = len(input_ids) if input_ids is not None else 0
+        input_token_num += self.reserve_input_token_num
         if input_token_num >= self.context_len:
             if self.server_args.allow_auto_truncate:
                 logger.warning(
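
A worked example of the reservation added above (illustrative numbers, not defaults): with speculative decoding enabled, speculative_num_draft_tokens slots are counted against the context window during validation, so a prompt that previously just fit may now be truncated or rejected.

context_len = 8192
reserve_input_token_num = 4    # speculative_num_draft_tokens when drafting is enabled
prompt_len = 8190

input_token_num = prompt_len + reserve_input_token_num
assert input_token_num >= context_len  # hits the auto-truncate / length-error branch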
@@ -674,7 +675,7 @@
         ):
             raise ValueError(
                 "The server is not configured to enable custom logit processor. "
-                "Please set `--enable-custom-logits-processor` to enable this feature."
+                "Please set `--enable-custom-logit-processor` to enable this feature."
             )
 
    def _validate_input_ids_in_vocab(
@@ -713,7 +714,6 @@
         )
 
         tokenized_obj = TokenizedGenerateReqInput(
-            obj.rid,
             input_text,
             input_ids,
             mm_inputs,
@@ -723,6 +723,7 @@
             obj.top_logprobs_num,
             obj.token_ids_logprob,
             obj.stream,
+            rid=obj.rid,
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
@@ -732,15 +733,18 @@
             custom_logit_processor=obj.custom_logit_processor,
             return_hidden_states=obj.return_hidden_states,
             data_parallel_rank=obj.data_parallel_rank,
+            priority=obj.priority,
+            extra_key=obj.extra_key,
         )
     elif isinstance(obj, EmbeddingReqInput):
         tokenized_obj = TokenizedEmbeddingReqInput(
-            obj.rid,
             input_text,
             input_ids,
             mm_inputs,
             token_type_ids,
             sampling_params,
+            rid=obj.rid,
+            priority=obj.priority,
         )
 
     return tokenized_obj
@@ -755,19 +759,30 @@
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]
 
-        # Batch tokenize all texts
-        encoded = self.tokenizer(texts)
-        input_ids_list = encoded["input_ids"]
+        # Check if any request is a cross-encoder request
+        is_cross_encoder_request = any(
+            isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
+            for req in requests
+        )
+
+        # Batch tokenize all texts using unified method
+        input_ids_list, token_type_ids_list = await self._tokenize_texts(
+            texts, is_cross_encoder_request
+        )
 
         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
             self._validate_one_request(obj[i], input_ids_list[i])
+            token_type_ids = (
+                token_type_ids_list[i] if token_type_ids_list is not None else None
+            )
             tokenized_objs.append(
                 self._create_tokenized_object(
-                    req, req.text, input_ids_list[i], None, None
+                    req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
+            trace_slice_end("tokenize", req.rid)
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs
 
@@ -795,9 +810,12 @@
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        trace_slice_start("dispatch", obj.rid)
+        tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
         return state
 
     def _send_batch_request(
@@ -1015,73 +1033,16 @@
         except StopAsyncIteration:
             pass
 
-    async def flush_cache(self) -> FlushCacheReqOutput:
-        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
-
-    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
-        """Clear the hierarchical cache storage."""
-        # Delegate to the scheduler to handle HiCacheStorage clearing
-        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
-            0
-        ]
-
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
-        req = AbortReq(rid, abort_all)
+        req = AbortReq(rid=rid, abort_all=abort_all)
         self.send_to_scheduler.send_pyobj(req)
-
         if self.enable_metrics:
-            self.metrics_collector.observe_one_aborted_request()
-
-    async def start_profile(
-        self,
-        output_dir: Optional[str] = None,
-        start_step: Optional[int] = None,
-        num_steps: Optional[int] = None,
-        activities: Optional[List[str]] = None,
-        with_stack: Optional[bool] = None,
-        record_shapes: Optional[bool] = None,
-        profile_by_stage: bool = False,
-    ):
-        self.auto_create_handle_loop()
-        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
-        with_stack = False if with_stack is False or env_with_stack is False else True
-        req = ProfileReq(
-            type=ProfileReqType.START_PROFILE,
-            output_dir=output_dir,
-            start_step=start_step,
-            num_steps=num_steps,
-            activities=activities,
-            with_stack=with_stack,
-            record_shapes=record_shapes,
-            profile_by_stage=profile_by_stage,
-            profile_id=str(time.time()),
-        )
-        return await self._execute_profile(req)
-
-    async def stop_profile(self):
-        self.auto_create_handle_loop()
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        return await self._execute_profile(req)
-
-    async def _execute_profile(self, req: ProfileReq):
-        result = (await self.profile_communicator(req))[0]
-        if not result.success:
-            raise RuntimeError(result.message)
-        return result
-
-    async def start_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
-
-    async def stop_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
-
-    async def dump_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
+            # TODO: also use custom_labels from the request
+            self.metrics_collector.observe_one_aborted_request(
+                self.metrics_collector.labels
+            )
 
     async def pause_generation(self):
         async with self.is_pause_cond:
@@ -1118,7 +1079,7 @@
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
         if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1143,291 +1104,6 @@ class TokenizerManager:
1143
1104
  all_paused_requests = [r.num_paused_requests for r in result]
1144
1105
  return all_success, all_message, all_paused_requests
1145
1106
 
1146
- async def init_weights_update_group(
1147
- self,
1148
- obj: InitWeightsUpdateGroupReqInput,
1149
- request: Optional[fastapi.Request] = None,
1150
- ) -> Tuple[bool, str]:
1151
- self.auto_create_handle_loop()
1152
- assert (
1153
- self.server_args.dp_size == 1
1154
- ), "dp_size must be 1 for init parameter update group"
1155
- result = (await self.init_weights_update_group_communicator(obj))[0]
1156
- return result.success, result.message
1157
-
1158
- async def update_weights_from_distributed(
1159
- self,
1160
- obj: UpdateWeightsFromDistributedReqInput,
1161
- request: Optional[fastapi.Request] = None,
1162
- ) -> Tuple[bool, str]:
1163
- self.auto_create_handle_loop()
1164
- assert (
1165
- self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
1166
- ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
1167
-
1168
- if obj.abort_all_requests:
1169
- self.abort_request(abort_all=True)
1170
-
1171
- # This means that weight sync
1172
- # cannot run while requests are in progress.
1173
- async with self.model_update_lock.writer_lock:
1174
- result = (await self.update_weights_from_distributed_communicator(obj))[0]
1175
- return result.success, result.message
1176
-
1177
- async def update_weights_from_tensor(
1178
- self,
1179
- obj: UpdateWeightsFromTensorReqInput,
1180
- request: Optional[fastapi.Request] = None,
1181
- ) -> Tuple[bool, str]:
1182
- self.auto_create_handle_loop()
1183
- assert (
1184
- self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
1185
- ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
1186
-
1187
- if obj.abort_all_requests:
1188
- self.abort_request(abort_all=True)
1189
-
1190
- # This means that weight sync
1191
- # cannot run while requests are in progress.
1192
- async with self.model_update_lock.writer_lock:
1193
- result = (await self.update_weights_from_tensor_communicator(obj))[0]
1194
- return result.success, result.message
1195
-
1196
- async def load_lora_adapter(
1197
- self,
1198
- obj: LoadLoRAAdapterReqInput,
1199
- _: Optional[fastapi.Request] = None,
1200
- ) -> LoadLoRAAdapterReqOutput:
1201
- self.auto_create_handle_loop()
1202
-
1203
- try:
1204
- if not self.server_args.enable_lora:
1205
- raise ValueError(
1206
- "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
1207
- )
1208
-
1209
- # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
1210
- # with dp_size > 1.
1211
- assert (
1212
- self.server_args.dp_size == 1
1213
- ), "dp_size must be 1 for dynamic lora loading"
1214
- logger.info(
1215
- "Start load Lora adapter. Lora name=%s, path=%s",
1216
- obj.lora_name,
1217
- obj.lora_path,
1218
- )
1219
-
1220
- async with self.lora_update_lock:
1221
- if (
1222
- self.server_args.max_loaded_loras is not None
1223
- and self.lora_registry.num_registered_loras
1224
- >= self.server_args.max_loaded_loras
1225
- ):
1226
- raise ValueError(
1227
- f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
1228
- f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
1229
- "Please unload some LoRA adapters before loading new ones."
1230
- )
1231
-
1232
- # Generate new uniquely identifiable LoRARef object.
1233
- new_adapter = LoRARef(
1234
- lora_name=obj.lora_name,
1235
- lora_path=obj.lora_path,
1236
- pinned=obj.pinned,
1237
- )
1238
-
1239
- # Trigger the actual loading operation at the backend processes.
1240
- obj.lora_id = new_adapter.lora_id
1241
- result = (await self.update_lora_adapter_communicator(obj))[0]
1242
-
1243
- # Register the LoRA adapter only after loading is successful.
1244
- if result.success:
1245
- await self.lora_registry.register(new_adapter)
1246
-
1247
- return result
1248
- except ValueError as e:
1249
- return LoadLoRAAdapterReqOutput(
1250
- success=False,
1251
- error_message=str(e),
1252
- )
1253
-
1254
- async def unload_lora_adapter(
1255
- self,
1256
- obj: UnloadLoRAAdapterReqInput,
1257
- _: Optional[fastapi.Request] = None,
1258
- ) -> UnloadLoRAAdapterReqOutput:
1259
- self.auto_create_handle_loop()
1260
-
1261
- try:
1262
- if not self.server_args.enable_lora:
1263
- raise ValueError(
1264
- "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
1265
- )
1266
-
1267
- assert (
1268
- obj.lora_name is not None
1269
- ), "lora_name must be provided to unload LoRA adapter"
1270
-
1271
- # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
1272
- # with dp_size > 1.
1273
- assert (
1274
- self.server_args.dp_size == 1
1275
- ), "dp_size must be 1 for dynamic lora loading"
1276
- logger.info(
1277
- "Start unload Lora adapter. Lora name=%s",
1278
- obj.lora_name,
1279
- )
1280
-
1281
- async with self.lora_update_lock:
1282
- # Unregister the LoRA adapter from the registry to stop new requests for this adapter
1283
- # from being started.
1284
- lora_id = await self.lora_registry.unregister(obj.lora_name)
1285
- obj.lora_id = lora_id
1286
-
1287
- # Initiate the actual unloading operation at the backend processes only after all
1288
- # ongoing requests using this LoRA adapter are finished.
1289
- await self.lora_registry.wait_for_unload(lora_id)
1290
- result = (await self.update_lora_adapter_communicator(obj))[0]
1291
-
1292
- return result
1293
- except ValueError as e:
1294
- return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
1295
-
1296
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
-    async def open_session(
-        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-
-        if obj.session_id is None:
-            obj.session_id = uuid.uuid4().hex
-        elif obj.session_id in self.session_futures:
-            return None
-
-        if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
-        self.send_to_scheduler.send_pyobj(obj)
-
-        self.session_futures[obj.session_id] = asyncio.Future()
-        session_id = await self.session_futures[obj.session_id]
-        del self.session_futures[obj.session_id]
-        return session_id
-
-    async def close_session(
-        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
-    ):
-        await self.send_to_scheduler.send_pyobj(obj)
-
-    async def get_internal_state(self) -> List[Dict[Any, Any]]:
-        req = GetInternalStateReq()
-        responses: List[GetInternalStateReqOutput] = (
-            await self.get_internal_state_communicator(req)
-        )
-        # Many DP ranks
-        return [res.internal_state for res in responses]
-
-    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
-        responses: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return [res.updated for res in responses]
-
-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
-    def get_log_request_metadata(self):
-        max_length = None
-        skip_names = None
-        out_skip_names = None
-        if self.log_requests:
-            if self.log_requests_level == 0:
-                max_length = 1 << 30
-                skip_names = set(
-                    [
-                        "text",
-                        "input_ids",
-                        "input_embeds",
-                        "image_data",
-                        "audio_data",
-                        "lora_path",
-                        "sampling_params",
-                    ]
-                )
-                out_skip_names = set(
-                    [
-                        "text",
-                        "output_ids",
-                        "embedding",
-                    ]
-                )
-            elif self.log_requests_level == 1:
-                max_length = 1 << 30
-                skip_names = set(
-                    [
-                        "text",
-                        "input_ids",
-                        "input_embeds",
-                        "image_data",
-                        "audio_data",
-                        "lora_path",
-                    ]
-                )
-                out_skip_names = set(
-                    [
-                        "text",
-                        "output_ids",
-                        "embedding",
-                    ]
-                )
-            elif self.log_requests_level == 2:
-                max_length = 2048
-            elif self.log_requests_level == 3:
-                max_length = 1 << 30
-            else:
-                raise ValueError(
-                    f"Invalid --log-requests-level: {self.log_requests_level=}"
-                )
-        return max_length, skip_names, out_skip_names
-
     def configure_logging(self, obj: ConfigureLoggingReq):
         if obj.log_requests is not None:
             self.log_requests = obj.log_requests
@@ -1492,6 +1168,9 @@ class TokenizerManager:
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )

     def dump_requests_before_crash(self):
         if self.crash_dump_performed:
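The background tasks registered here are wrapped in `print_exception_wrapper` so a failing coroutine cannot die silently. A hedged sketch of what such a wrapper typically looks like (sglang's actual helper may additionally terminate the process; this sketch only logs and re-raises):

```python
import logging
import traceback

logger = logging.getLogger(__name__)

async def print_exception_wrapper(func):
    """Await a long-lived coroutine function; log a full traceback if it crashes."""
    try:
        await func()
    except Exception:
        logger.error("Background task crashed:\n%s", traceback.format_exc())
        raise
```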
@@ -1583,12 +1262,12 @@ class TokenizerManager:
         # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
+            remaining_rids = list(self.rid_to_state.keys())

             if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
-                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
-                    remain_num_req,
+                    "Signal SIGTERM received while health check failed. Force exiting."
                 )
                 self.dump_requests_before_crash()
                 break
@@ -1596,13 +1275,12 @@ class TokenizerManager:
             elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                 # if force shutdown flag set, exit immediately
                 logger.error(
-                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
-                    remain_num_req,
+                    "Signal SIGTERM received while force shutdown flag set. Force exiting."
                 )
                 break

             logger.info(
-                f"Gracefully exiting... remaining number of requests {remain_num_req}"
+                f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
             )
             if remain_num_req > 0:
                 await asyncio.sleep(5)
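The drain loop above follows a common graceful-shutdown pattern: poll the table of in-flight request ids and sleep until it empties, unless a force-exit condition fires first. A standalone sketch of the same idea, with illustrative names:

```python
import asyncio

async def drain_requests(rid_to_state: dict, poll_interval: float = 5.0) -> None:
    """Poll the in-flight request table until it empties, then return."""
    while rid_to_state:
        remaining = list(rid_to_state.keys())
        print(f"Gracefully exiting... {len(remaining)} request(s) remaining: {remaining}")
        await asyncio.sleep(poll_interval)  # another task removes entries as requests finish
    print("All requests drained; shutting down.")
```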
@@ -1623,7 +1301,10 @@ class TokenizerManager:
     def _handle_batch_output(
         self,
         recv_obj: Union[
-            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
+            BatchStrOutput,
+            BatchEmbeddingOutput,
+            BatchMultimodalOutput,
+            BatchTokenIDOutput,
         ],
     ):
         for i, rid in enumerate(recv_obj.rids):
@@ -1657,7 +1338,7 @@ class TokenizerManager:
                     i,
                 )

-            if not isinstance(recv_obj, BatchEmbeddingOut):
+            if not isinstance(recv_obj, BatchEmbeddingOutput):
                 meta_info.update(
                     {
                         "completion_tokens": recv_obj.completion_tokens[i],
@@ -1668,7 +1349,7 @@ class TokenizerManager:
             if getattr(recv_obj, "output_hidden_states", None):
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

-            if isinstance(recv_obj, BatchStrOut):
+            if isinstance(recv_obj, BatchStrOutput):
                 state.text += recv_obj.output_strs[i]
                 if state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
@@ -1683,7 +1364,7 @@ class TokenizerManager:
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchTokenIDOut):
+            elif isinstance(recv_obj, BatchTokenIDOutput):
                 if self.server_args.stream_output and state.obj.stream:
                     state.output_ids.extend(recv_obj.output_ids[i])
                     output_token_ids = state.output_ids[state.last_output_offset :]
@@ -1696,10 +1377,10 @@ class TokenizerManager:
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
-            elif isinstance(recv_obj, BatchMultimodalOut):
+            elif isinstance(recv_obj, BatchMultimodalOutput):
                 raise NotImplementedError("BatchMultimodalOut not implemented")
             else:
-                assert isinstance(recv_obj, BatchEmbeddingOut)
+                assert isinstance(recv_obj, BatchEmbeddingOutput)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
@@ -1711,6 +1392,9 @@ class TokenizerManager:
                     meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+                trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+
                 del self.rid_to_state[rid]

                 # Mark ongoing LoRA request as finished.
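Note the timestamp convention in the new tracing call: `trace_req_finish` takes integer nanoseconds, so the float seconds from `time.time()` are converted with `int(seconds * 1e9)`. A one-line illustration:

```python
import time

finished_time = time.time()        # float seconds since the epoch, e.g. 1727000000.123456
ts_ns = int(finished_time * 1e9)   # integer nanoseconds, as the trace_* hooks expect
```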
@@ -1735,7 +1419,7 @@ class TokenizerManager:
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
-        recv_obj: BatchStrOut,
+        recv_obj: BatchStrOutput,
         recv_obj_index: int,
     ):
         if recv_obj.input_token_logprobs_val is None:
@@ -1853,13 +1537,19 @@ class TokenizerManager:
                 ret.append(None)
         return ret

-    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
         completion_tokens = (
             recv_obj.completion_tokens[i]
             if getattr(recv_obj, "completion_tokens", None)
             else 0
         )

+        custom_labels = getattr(state.obj, "custom_labels", None)
+        labels = (
+            {**self.metrics_collector.labels, **custom_labels}
+            if custom_labels
+            else self.metrics_collector.labels
+        )
         if (
             state.first_token_time == 0.0
             and self.disaggregation_mode != DisaggregationMode.PREFILL
@@ -1867,7 +1557,7 @@ class TokenizerManager:
             state.first_token_time = state.last_time = time.time()
             state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
-                state.first_token_time - state.created_time
+                labels, state.first_token_time - state.created_time
             )
         else:
             num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1875,6 +1565,7 @@ class TokenizerManager:
                 new_time = time.time()
                 interval = new_time - state.last_time
                 self.metrics_collector.observe_inter_token_latency(
+                    labels,
                     interval,
                     num_new_tokens,
                 )
@@ -1889,6 +1580,7 @@ class TokenizerManager:
                 or state.obj.sampling_params.get("structural_tag", None)
             )
             self.metrics_collector.observe_one_finished_request(
+                labels,
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
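The `custom_labels` change above layers per-request labels over the collector's base labels, with request-supplied keys winning on conflict (standard dict-unpacking semantics). A minimal standalone illustration of that merge, with illustrative label values:

```python
base_labels = {"model_name": "llama-3", "engine": "sglang"}   # collector-level labels (illustrative)
custom_labels = {"tenant": "team-a", "engine": "override"}    # per-request labels (illustrative)

labels = {**base_labels, **custom_labels} if custom_labels else base_labels
print(labels)  # {'model_name': 'llama-3', 'engine': 'override', 'tenant': 'team-a'}
```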
@@ -1941,7 +1633,7 @@ class TokenizerManager:

         asyncio.create_task(asyncio.to_thread(background_task))

-    def _handle_abort_req(self, recv_obj):
+    def _handle_abort_req(self, recv_obj: AbortReq):
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
@@ -2060,11 +1752,15 @@ class TokenizerManager:
             # the next position after the last token in the prompt
             output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])

-            # Throw an error here if output_logprobs is None
-            if output_logprobs is None:
+            # Check if output_logprobs is properly populated
+            if (
+                output_logprobs is None
+                or not output_logprobs
+                or len(output_logprobs) == 0
+            ):
                 raise RuntimeError(
-                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
-                    "This usually indicates a problem with the scoring request or the backend output."
+                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
+                    "This indicates token_ids_logprobs were not computed properly for the scoring request."
                 )

             for logprob, token_id, _ in output_logprobs[0]:
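Downstream, the scoring path turns these `(logprob, token_id, _)` triples into per-token scores. A hedged sketch of the usual conversion, a softmax over the candidate tokens' logprobs; the helper name is illustrative, not the exact sglang function:

```python
import math

def normalize_token_scores(output_logprobs):
    """Map [(logprob, token_id, _), ...] to {token_id: probability}, renormalized."""
    logprobs = {token_id: lp for lp, token_id, _ in output_logprobs}
    denom = sum(math.exp(lp) for lp in logprobs.values())
    return {tid: math.exp(lp) / denom for tid, lp in logprobs.items()}

print(normalize_token_scores([(-0.1, 9891, None), (-2.3, 2360, None)]))
# {9891: ~0.90, 2360: ~0.10}
```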
@@ -2089,6 +1785,43 @@ class TokenizerManager:

         return scores

+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return
+
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_update_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_update_req)
+
+    def _trace_request_start(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        created_time: Optional[float] = None,
+    ):
+        if obj.is_single:
+            bootstrap_room = (
+                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
+            )
+            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
+        else:
+            for i in range(len(obj.rid)):
+                bootstrap_room = (
+                    obj.bootstrap_room[i]
+                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+                    else None
+                )
+                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_slice_start(
+                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                )
+


 class ServerStatus(Enum):
     Up = "Up"
@@ -2134,57 +1867,12 @@ class SignalHandler:

     def running_phase_sigquit_handler(self, signum=None, frame=None):
         logger.error(
-            "Received sigquit from a child process. It usually means the child failed."
+            f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
         )
         self.tokenizer_manager.dump_requests_before_crash()
         kill_process_tree(os.getpid())


-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    enable_multi_tokenizer = False
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
 # Note: request abort handling logic
 # We should handle all of the following cases correctly.
 #
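The removed `_Communicator` (presumably relocated elsewhere in this release's refactor) implements a simple fan-in rendezvous: send one request, collect exactly `fan_out` replies, and serialize callers so at most one request is in flight. A minimal asyncio sketch of the same idea, independent of ZMQ; the class and parameter names are illustrative, and an `asyncio.Lock` stands in for the original's ready queue:

```python
import asyncio
from typing import Any, Callable, List, Optional

class FanInCommunicator:
    """Send at most one request at a time; wait for exactly `fan_out` replies."""

    def __init__(self, send: Callable[[Any], None], fan_out: int):
        self._send = send
        self._fan_out = fan_out
        self._lock = asyncio.Lock()              # serializes callers (one in-flight request)
        self._results: List[Any] = []
        self._done: Optional[asyncio.Event] = None

    async def __call__(self, obj: Any) -> List[Any]:
        async with self._lock:
            self._results, self._done = [], asyncio.Event()
            self._send(obj)                      # e.g. a ZMQ send_pyobj in sglang
            await self._done.wait()              # set once fan_out replies have arrived
            return self._results

    def handle_recv(self, reply: Any) -> None:
        # Called by the receive loop for each worker reply.
        self._results.append(reply)
        if len(self._results) == self._fan_out:
            self._done.set()
```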