sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
+    DecodeKVCacheOffloadManager,
+)
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
@@ -57,11 +60,6 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
@@ -72,20 +70,26 @@ from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
     ClearHiCacheReqOutput,
     CloseSessionReqInput,
+    DestroyWeightsUpdateGroupReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
+    ExpertDistributionReqType,
     FlushCacheReqInput,
     FlushCacheReqOutput,
     FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     HealthCheckOutput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqOutput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     MultiTokenizerRegisterReq,
-    MultiTokenizerWarpper,
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -93,6 +97,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
     RpcReqOutput,
+    SendWeightsToRemoteInstanceReqInput,
+    SendWeightsToRemoteInstanceReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     SlowDownReqInput,
@@ -110,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     MultimodalInputs,
     Req,
+    RequestStage,
     ScheduleBatch,
     global_server_args_dict,
 )
@@ -134,17 +141,28 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
-from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatchOutput,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.tracing.trace import (
+    process_tracing_init,
+    trace_set_proc_propagate_context,
+    trace_set_thread_info,
+    trace_slice_batch,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -155,9 +173,10 @@ from sglang.srt.utils import (
     freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
+    get_int_env_var,
     get_zmq_socket,
-    is_cpu,
     kill_itself_when_parent_died,
+    numa_bind_to_node,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
     require_mlp_sync,
@@ -166,6 +185,11 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -174,24 +198,59 @@ logger = logging.getLogger(__name__)
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
-_is_cpu = is_cpu()
-
 
 @dataclass
 class GenerationBatchResult:
     logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
     next_token_ids: Optional[List[int]]
+    can_run_cuda_graph: bool
+
+    # For output processing
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
-    bid: int
-    can_run_cuda_graph: bool
+
+    @classmethod
+    def from_forward_batch_output(
+        cls,
+        forward_batch_output: ForwardBatchOutput,
+        extend_input_len_per_req: List[int],
+        extend_logprob_start_len_per_req: List[int],
+    ):
+        # TODO(lsyin): remove this workaround logic and try to unify output classes
+
+        return cls(
+            logits_output=forward_batch_output.logits_output,
+            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
+            next_token_ids=forward_batch_output.next_token_ids,
+            extend_input_len_per_req=extend_input_len_per_req,
+            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
+        )
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): also simplify this logic
+        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
+        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
 
 
 @dataclass
 class EmbeddingBatchResult:
     embeddings: torch.Tensor
-    bid: int
 
 
 class Scheduler(
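
The hunk above drops the positional bid plumbing and gives GenerationBatchResult two named constructors. A minimal standalone sketch of the same pattern, using toy stand-ins rather than the real sglang classes:

from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ToyForwardOutput:
    # Stand-in for the output object produced by the model worker.
    next_token_ids: List[int]
    can_run_cuda_graph: bool


@dataclass
class ToyBatchResult:
    # Stand-in for GenerationBatchResult.
    next_token_ids: Optional[List[int]]
    extend_input_len_per_req: Optional[List[int]]
    can_run_cuda_graph: bool

    @classmethod
    def from_forward_output(cls, out: ToyForwardOutput, extend_lens: List[int]):
        # Mirrors from_forward_batch_output: copy fields from the worker output.
        return cls(out.next_token_ids, extend_lens, out.can_run_cuda_graph)

    @classmethod
    def from_proxy(cls, proxy: Dict[str, Any], can_run_cuda_graph: bool):
        # Mirrors from_pp_proxy: optional per-request fields are looked up with .get().
        return cls(
            next_token_ids=proxy["next_token_ids"],
            extend_input_len_per_req=proxy.get("extend_input_len_per_req"),
            can_run_cuda_graph=can_run_cuda_graph,
        )


print(ToyBatchResult.from_proxy({"next_token_ids": [7, 8]}, can_run_cuda_graph=True))
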
@@ -213,7 +272,6 @@ class Scheduler(
         moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
-        dp_balance_meta: Optional[DPBalanceMeta] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -226,6 +284,13 @@ class Scheduler(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
+        self.enable_priority_scheduling = server_args.enable_priority_scheduling
+        self.schedule_low_priority_values_first = (
+            server_args.schedule_low_priority_values_first
+        )
+        self.priority_scheduling_preemption_threshold = (
+            server_args.priority_scheduling_preemption_threshold
+        )
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
@@ -234,7 +299,10 @@ class Scheduler(
         self.enable_metrics_for_all_schedulers = (
             server_args.enable_metrics_for_all_schedulers
         )
-        self.enable_kv_cache_events = server_args.kv_events_config is not None
+        self.enable_kv_cache_events = bool(
+            server_args.kv_events_config and tp_rank == 0
+        )
+        self.enable_trace = server_args.enable_trace
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
@@ -348,9 +416,39 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        elif self.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+            self.draft_worker = NGRAMWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None
 
+        # Dispatch the model worker
+        if self.spec_algorithm.is_none():
+            self.model_worker = self.tp_worker
+        else:
+            self.model_worker = self.draft_worker
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -401,7 +499,7 @@ class Scheduler(
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )
 
         # Init memory pool and cache
@@ -458,7 +556,12 @@ class Scheduler(
             self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
+            self.enable_priority_scheduling,
+            self.schedule_low_priority_values_first,
         )
+        # Enable preemption for priority scheduling.
+        self.try_preemption = self.enable_priority_scheduling
+
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
@@ -488,7 +591,7 @@ class Scheduler(
             enable=server_args.enable_memory_saver
         )
         self.offload_tags = set()
-        self.init_profier()
+        self.init_profiler()
 
         self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
@@ -499,8 +602,9 @@ class Scheduler(
 
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
-        self.init_kv_events(server_args.kv_events_config)
-        self.init_dp_balance(dp_balance_meta)
+
+        if self.enable_kv_cache_events:
+            self.init_kv_events(server_args.kv_events_config)
 
         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -511,6 +615,9 @@ class Scheduler(
         if get_bool_env_var("SGLANG_GC_LOG"):
             configure_gc_logger()
 
+        # Init prefill kv split size when deterministic inference is enabled with various attention backends
+        self.init_deterministic_inference_config()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -525,6 +632,15 @@ class Scheduler(
                 (CloseSessionReqInput, self.close_session),
                 (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                 (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
+                (
+                    InitWeightsSendGroupForRemoteInstanceReqInput,
+                    self.init_weights_send_group_for_remote_instance,
+                ),
+                (
+                    SendWeightsToRemoteInstanceReqInput,
+                    self.send_weights_to_remote_instance,
+                ),
                 (
                     UpdateWeightsFromDistributedReqInput,
                     self.update_weights_from_distributed,
@@ -543,9 +659,27 @@ class Scheduler(
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                 (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+                (GetLoadReqInput, self.get_load),
             ]
         )
 
+    def init_deterministic_inference_config(self):
+        """Initialize deterministic inference configuration for different attention backends."""
+        if not self.server_args.enable_deterministic_inference:
+            self.truncation_align_size = None
+            return
+
+        backend_sizes = {
+            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
+            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
+        }
+        env_var, default_size = backend_sizes.get(
+            self.server_args.attention_backend, (None, None)
+        )
+        self.truncation_align_size = (
+            get_int_env_var(env_var, default_size) if env_var else None
+        )
+
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
617
751
  else self.tp_cpu_group
618
752
  ),
619
753
  page_size=self.page_size,
754
+ eviction_policy=server_args.radix_eviction_policy,
620
755
  hicache_ratio=server_args.hicache_ratio,
621
756
  hicache_size=server_args.hicache_size,
622
757
  hicache_write_policy=server_args.hicache_write_policy,
623
758
  hicache_io_backend=server_args.hicache_io_backend,
624
759
  hicache_mem_layout=server_args.hicache_mem_layout,
760
+ enable_metrics=self.enable_metrics,
625
761
  hicache_storage_backend=server_args.hicache_storage_backend,
626
762
  hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
627
763
  model_name=server_args.served_model_name,
628
764
  storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
765
+ is_eagle=self.spec_algorithm.is_eagle(),
629
766
  )
630
767
  self.tp_worker.register_hicache_layer_transfer_counter(
631
768
  self.tree_cache.cache_controller.layer_done_counter
@@ -641,18 +778,21 @@ class Scheduler(
641
778
  page_size=self.page_size,
642
779
  disable=server_args.disable_radix_cache,
643
780
  )
644
- elif self.enable_lora:
645
- assert (
646
- not self.enable_hierarchical_cache
647
- ), "LoRA radix cache doesn't support hierarchical cache"
648
- assert (
649
- self.schedule_policy == "fcfs"
650
- ), "LoRA radix cache only supports FCFS policy"
651
- self.tree_cache = LoRARadixCache(
781
+ elif server_args.enable_lmcache:
782
+ from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
783
+ LMCRadixCache,
784
+ )
785
+
786
+ self.tree_cache = LMCRadixCache(
652
787
  req_to_token_pool=self.req_to_token_pool,
653
788
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
654
789
  page_size=self.page_size,
655
790
  disable=server_args.disable_radix_cache,
791
+ model_config=self.model_config,
792
+ tp_size=self.tp_size,
793
+ rank=self.tp_rank,
794
+ tp_group=self.tp_group,
795
+ eviction_policy=server_args.radix_eviction_policy,
656
796
  )
657
797
  else:
658
798
  self.tree_cache = RadixCache(
@@ -661,16 +801,36 @@ class Scheduler(
661
801
  page_size=self.page_size,
662
802
  disable=server_args.disable_radix_cache,
663
803
  enable_kv_cache_events=self.enable_kv_cache_events,
804
+ eviction_policy=server_args.radix_eviction_policy,
805
+ is_eagle=self.spec_algorithm.is_eagle(),
664
806
  )
665
807
 
808
+ if (
809
+ server_args.disaggregation_mode == "decode"
810
+ and server_args.disaggregation_decode_enable_offload_kvcache
811
+ ):
812
+ self.decode_offload_manager = DecodeKVCacheOffloadManager(
813
+ req_to_token_pool=self.req_to_token_pool,
814
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
815
+ tp_group=(
816
+ self.attn_tp_cpu_group
817
+ if self.server_args.enable_dp_attention
818
+ else self.tp_cpu_group
819
+ ),
820
+ tree_cache=self.tree_cache,
821
+ server_args=self.server_args,
822
+ )
823
+ else:
824
+ self.decode_offload_manager = None
825
+
666
826
  self.decode_mem_cache_buf_multiplier = (
667
827
  1
668
828
  if self.spec_algorithm.is_none()
669
829
  else (
670
830
  server_args.speculative_num_draft_tokens
671
831
  + (
672
- server_args.speculative_eagle_topk
673
- * server_args.speculative_num_steps
832
+ (server_args.speculative_eagle_topk or 1)
833
+ * (server_args.speculative_num_steps or 1)
674
834
  )
675
835
  )
676
836
  )
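
The multiplier above is 1 without speculative decoding and otherwise num_draft_tokens + topk * steps, where the new "or 1" guards default unset EAGLE parameters to 1 (e.g. when the topk/steps arguments are not configured). A worked example with illustrative values:

def decode_buf_multiplier(num_draft_tokens, eagle_topk=None, num_steps=None, spec_enabled=True):
    # Mirrors the expression above; the values passed in are illustrative, not defaults.
    if not spec_enabled:
        return 1
    return num_draft_tokens + (eagle_topk or 1) * (num_steps or 1)


print(decode_buf_multiplier(4, eagle_topk=2, num_steps=3))  # 4 + 2 * 3 = 10
print(decode_buf_multiplier(4))                             # 4 + 1 * 1 = 5 (topk/steps unset)
print(decode_buf_multiplier(0, spec_enabled=False))         # 1
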
@@ -693,7 +853,7 @@ class Scheduler(
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
-                dtype=self.model_config.dtype,
+                hidden_states_dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )
 
@@ -713,7 +873,7 @@ class Scheduler(
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=(
                    None
-                    if self.draft_worker is None
+                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -742,7 +902,7 @@ class Scheduler(
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
-                dtype=self.model_config.dtype,
+                hidden_states_dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )
 
@@ -750,7 +910,7 @@ class Scheduler(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=(
                    None
-                    if self.draft_worker is None
+                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -845,7 +1005,6 @@ class Scheduler(
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
-        bids = [None] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
@@ -866,10 +1025,7 @@ class Scheduler(
                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
-                        next_token_ids, bids[mb_id] = (
-                            result.next_token_ids,
-                            result.bid,
-                        )
+                        next_token_ids = result.next_token_ids
                        if self.cur_batch.return_logprob:
                            pp_outputs = PPProxyTensors(
                                {
@@ -917,17 +1073,10 @@ class Scheduler(
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None
-                    output_result = GenerationBatchResult(
+
+                    output_result = GenerationBatchResult.from_pp_proxy(
                        logits_output=logits_output,
-                        pp_hidden_states_proxy_tensors=None,
-                        next_token_ids=next_pp_outputs["next_token_ids"],
-                        extend_input_len_per_req=next_pp_outputs.tensors.get(
-                            "extend_input_len_per_req", None
-                        ),
-                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
-                            "extend_logprob_start_len_per_req", None
-                        ),
-                        bid=bids[next_mb_id],
+                        next_pp_outputs=next_pp_outputs,
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
@@ -935,8 +1084,6 @@ class Scheduler(
 
                # (not last rank)
                if not self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
@@ -958,8 +1105,10 @@ class Scheduler(
 
                # send out proxy tensors to the next stage
                if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                    self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                        all_gather_group=self.attn_tp_group,
                    )
 
@@ -1069,6 +1218,15 @@ class Scheduler(
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
+
+        if self.enable_trace:
+            for req in recv_reqs:
+                if isinstance(
+                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
+                    trace_set_proc_propagate_context(req.rid, req.trace_context)
+                    trace_slice_start("", req.rid, anonymous=True)
+
        return recv_reqs
 
    def process_input_requests(self, recv_reqs: List):
@@ -1082,27 +1240,13 @@ class Scheduler(
                self.return_health_check_ct += 1
                continue
 
-            # If it is a work request, accept or reject the request based on the request queue size.
-            if is_work_request(recv_req):
-                if len(self.waiting_queue) + 1 > self.max_queued_requests:
-                    abort_req = AbortReq(
-                        recv_req.rid,
-                        finished_reason={
-                            "type": "abort",
-                            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
-                            "message": "The request queue is full.",
-                        },
-                    )
-                    self.send_to_tokenizer.send_pyobj(abort_req)
-                    continue
-
-            # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
-            if isinstance(recv_req, MultiTokenizerWarpper):
+            # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWrapper):
                worker_id = recv_req.worker_id
                recv_req = recv_req.obj
                output = self._request_dispatcher(recv_req)
                if output is not None:
-                    output = MultiTokenizerWarpper(worker_id, output)
+                    output = MultiTokenizerWrapper(worker_id, output)
                    self.send_to_tokenizer.send_pyobj(output)
                continue
 
@@ -1114,12 +1258,20 @@ class Scheduler(
            else:
                self.send_to_tokenizer.send_pyobj(output)
 
+    def init_req_max_new_tokens(self, req):
+        req.sampling_params.max_new_tokens = min(
+            (
+                req.sampling_params.max_new_tokens
+                if req.sampling_params.max_new_tokens is not None
+                else 1 << 30
+            ),
+            self.max_req_len - len(req.origin_input_ids) - 1,
+        )
+
    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
-        self.maybe_update_dp_balance_data(recv_req)
-
        # Create a new request
        if (
            recv_req.session_params is None
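
init_req_max_new_tokens clamps the request's budget: an unset max_new_tokens is treated as effectively unbounded (1 << 30) and then capped by the room left below max_req_len. A small standalone illustration of the arithmetic:

def clamp_max_new_tokens(requested, max_req_len, prompt_len):
    # Mirrors the min() clamp in init_req_max_new_tokens above.
    unbounded = 1 << 30
    return min(requested if requested is not None else unbounded,
               max_req_len - prompt_len - 1)


print(clamp_max_new_tokens(None, max_req_len=8192, prompt_len=1000))  # 7191
print(clamp_max_new_tokens(64, max_req_len=8192, prompt_len=1000))    # 64
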
@@ -1153,8 +1305,13 @@ class Scheduler(
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
+                disagg_mode=self.disaggregation_mode,
                data_parallel_rank=recv_req.data_parallel_rank,
                vocab_size=self.model_config.vocab_size,
+                priority=recv_req.priority,
+                metrics_collector=(
+                    self.metrics_collector if self.enable_metrics else None
+                ),
            )
            req.tokenizer = self.tokenizer
 
@@ -1177,6 +1334,7 @@ class Scheduler(
            req.set_finish_with_abort(
                f"Invalid request: session id {recv_req.session_params.id} does not exist"
            )
+            self.init_req_max_new_tokens(req)
            self._add_request_to_queue(req)
            return
        else:
@@ -1184,6 +1342,7 @@ class Scheduler(
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
+                self.init_req_max_new_tokens(req)
                self._add_request_to_queue(req)
                return
 
@@ -1203,9 +1362,13 @@ class Scheduler(
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
            )
+            self.init_req_max_new_tokens(req)
            self._add_request_to_queue(req)
            return
 
+        # initialize before returning
+        self.init_req_max_new_tokens(req)
+
        # Validate prompt length
        error_msg = validate_input_length(
            req,
@@ -1220,26 +1383,25 @@ class Scheduler(
        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
+            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
+            # to skip input logprob computation entirely
+            if req.is_prefill_only:
+                req.logprob_start_len = len(req.origin_input_ids)
+            else:
+                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
+                req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len
 
-        if req.logprob_start_len >= len(req.origin_input_ids):
+        if not req.is_prefill_only and req.logprob_start_len >= len(
+            req.origin_input_ids
+        ):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return
 
-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_len - len(req.origin_input_ids) - 1,
-        )
-
        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
@@ -1270,7 +1432,6 @@ class Scheduler(
                req.set_finish_with_abort(error_msg)
 
        if add_to_grammar_queue:
-            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)
@@ -1286,19 +1447,6 @@ class Scheduler(
        for tokenized_req in recv_req:
            self.handle_generate_request(tokenized_req)
 
-    def _add_request_to_queue(self, req: Req):
-        req.queue_time_start = time.perf_counter()
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self._prefetch_kvcache(req)
-            self.disagg_prefill_bootstrap_queue.add(
-                req, self.model_config.num_key_value_heads
-            )
-        elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            self.disagg_decode_prealloc_queue.add(req)
-        else:
-            self._prefetch_kvcache(req)
-            self.waiting_queue.append(req)
-
    def _prefetch_kvcache(self, req: Req):
        if self.enable_hicache_storage:
            req.init_next_round_input(self.tree_cache)
@@ -1312,16 +1460,87 @@ class Scheduler(
                req.rid, req.last_host_node, new_input_tokens, last_hash
            )
 
-    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.disagg_prefill_bootstrap_queue.extend(
-                reqs, self.model_config.num_key_value_heads
+    def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.NULL:
+            self._set_or_validate_priority(req)
+            if self._abort_on_queued_limit(req):
+                return
+            self._prefetch_kvcache(req)
+            self.waiting_queue.append(req)
+            req.time_stats.wait_queue_entry_time = time.perf_counter()
+            trace_slice_end("process req", req.rid, auto_next_anon=True)
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self._prefetch_kvcache(req)
+            self.disagg_prefill_bootstrap_queue.add(
+                req, self.model_config.num_key_value_heads
            )
+            req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            # If this is a decode server, we put the request to the decode pending prealloc queue
-            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
+            self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
+            if not is_retracted:
+                req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
        else:
-            self.waiting_queue.extend(reqs)
+            raise ValueError(f"Invalid {self.disaggregation_mode=}")
+
+    def _set_or_validate_priority(self, req: Req):
+        """Set the default priority value, or abort the request based on the priority scheduling mode."""
+        if self.enable_priority_scheduling and req.priority is None:
+            if self.schedule_low_priority_values_first:
+                req.priority = sys.maxsize
+            else:
+                req.priority = -sys.maxsize - 1
+        elif not self.enable_priority_scheduling and req.priority is not None:
+            abort_req = AbortReq(
+                finished_reason={
+                    "type": "abort",
+                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
+                },
+                rid=req.rid,
+            )
+            self.send_to_tokenizer.send_pyobj(abort_req)
+
+    def _abort_on_queued_limit(self, recv_req: Req) -> bool:
+        """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
+        if (
+            self.max_queued_requests is None
+            or len(self.waiting_queue) + 1 <= self.max_queued_requests
+        ):
+            return False
+
+        # Reject the incoming request by default.
+        req_to_abort = recv_req
+        message = "The request queue is full."
+        if self.enable_priority_scheduling:
+            # With priority scheduling, consider aboritng an existing request based on the priority.
+            # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
+            # max(...) + (direction * priority, queue_time_start) picks the least-preferred request.
+            # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
+            direction = 1 if self.schedule_low_priority_values_first else -1
+            key_fn = lambda item: (
+                direction * item[1].priority,
+                item[1].time_stats.wait_queue_entry_time,
+            )
+            idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
+            abort_existing_req = (
+                direction * recv_req.priority < direction * candidate_req.priority
+            )
+            if abort_existing_req:
+                self.waiting_queue.pop(idx)
+                req_to_abort = candidate_req
+                message = "The request is aborted by a higher priority request."
+
+        self.send_to_tokenizer.send_pyobj(
+            AbortReq(
+                finished_reason={
+                    "type": "abort",
+                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                    "message": message,
+                },
+                rid=req_to_abort.rid,
+            )
+        )
+        return req_to_abort.rid == recv_req.rid
 
    def handle_embedding_request(
        self,
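
_abort_on_queued_limit picks its victim with a single max() over (direction * priority, wait_queue_entry_time): direction is 1 when low values mean high priority and -1 otherwise, so the least-preferred request always sorts last, and ties fall on the most recently queued request. A small standalone check of that rule, with toy request objects:

from dataclasses import dataclass


@dataclass
class QueuedReq:
    rid: str
    priority: int
    entry_time: float


def pick_victim(queue, direction):
    # Same key as the hunk above: worst priority first, newest entry breaks ties.
    key = lambda item: (direction * item[1].priority, item[1].entry_time)
    return max(enumerate(queue), key=key)


queue = [QueuedReq("a", 0, 1.0), QueuedReq("b", 5, 2.0), QueuedReq("c", 5, 3.0)]
idx, victim = pick_victim(queue, direction=1)   # low value = high priority
print(idx, victim.rid)                          # 2 c -> worst priority, newest entry
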
@@ -1333,6 +1552,7 @@ class Scheduler(
             recv_req.input_ids,
             recv_req.sampling_params,
             token_type_ids=recv_req.token_type_ids,
+            priority=recv_req.priority,
         )
         req.tokenizer = self.tokenizer

@@ -1409,9 +1629,11 @@ class Scheduler(
         _, _, available_size, evictable_size = self._get_token_info()
         protected_size = self.tree_cache.protected_size()
         memory_leak = (available_size + evictable_size) != (
+            # self.max_total_num_tokens
+            # if not self.enable_hierarchical_cache
+            # else self.max_total_num_tokens - protected_size
             self.max_total_num_tokens
-            if not self.enable_hierarchical_cache
-            else self.max_total_num_tokens - protected_size
+            - protected_size
         )
         token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

@@ -1462,6 +1684,20 @@ class Scheduler(
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )
             self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()
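The memory-leak check rewritten in the `@@ -1409` hunk above encodes a simple invariant: once requests drain, every KV-cache token should be either free (available) or reclaimable from the radix cache (evictable), apart from the tokens the cache keeps protected. A hedged, self-contained restatement of that check, with field semantics assumed from the surrounding code:

```python
def has_token_leak(max_total_num_tokens: int,
                   available_size: int,
                   evictable_size: int,
                   protected_size: int) -> bool:
    # Tokens that are neither free, evictable, nor protected indicate a leak in the pool.
    return (available_size + evictable_size) != (max_total_num_tokens - protected_size)


# 700 free + 250 evictable + 50 protected == 1000 total -> no leak
assert not has_token_leak(1000, 700, 250, 50)
# 50 tokens are unaccounted for -> leak
assert has_token_leak(1000, 700, 200, 50)
```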
 
@@ -1509,7 +1745,12 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+                self.req_to_token_pool.free(
+                    self.chunked_req.req_pool_idx, free_mamba_cache=False
+                )
+            else:
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.last_batch.chunked_req is not None:
                 # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks an outdated chunked_req.
@@ -1556,7 +1797,6 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
-            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1572,6 +1812,10 @@ class Scheduler(
         if self.grammar_queue:
             self.move_ready_grammar_requests()

+        if self.try_preemption:
+            # Reset batch_is_full to try preemption with a prefill adder.
+            self.running_batch.batch_is_full = False
+
         # Handle the cases where prefill is not allowed
         if (
             self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1584,7 +1828,11 @@ class Scheduler(
         # as the space for the chunked request has just been released.
         # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
         # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
-        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
+        if (
+            self.get_num_allocatable_reqs(running_bs) <= 0
+            and not self.chunked_req
+            and not self.try_preemption
+        ):
             self.running_batch.batch_is_full = True
             return None

@@ -1604,6 +1852,7 @@ class Scheduler(
             self.max_prefill_tokens,
             self.chunked_prefill_size,
             running_bs if self.is_mixed_chunk else 0,
+            self.priority_scheduling_preemption_threshold,
         )

         if self.chunked_req is not None:
@@ -1624,15 +1873,19 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break

+            running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
             if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
-                break
-
             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 # In prefill mode, prealloc queue and transfer queue can also take memory,
                 # so we need to check the actual available size.
                 if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                     self.running_batch.batch_is_full = True
+
+            if self.running_batch.batch_is_full:
+                if not self.try_preemption:
+                    break
+                if not adder.preempt_to_schedule(req, self.server_args):
                     break

             if self.enable_hicache_storage:
@@ -1642,7 +1895,11 @@ class Scheduler(
                     continue

             req.init_next_round_input(self.tree_cache)
-            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
+            res = adder.add_one_req(
+                req,
+                has_chunked_req=(self.chunked_req is not None),
+                truncation_align_size=self.truncation_align_size,
+            )

             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
@@ -1663,11 +1920,14 @@ class Scheduler(
         if self.enable_metrics:
             # only record queue time when enable_metrics is True to avoid overhead
             for req in can_run_list:
-                req.queue_time_end = time.perf_counter()
+                req.add_latency(RequestStage.PREFILL_WAITING)

         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
+        if adder.preempt_list:
+            for req in adder.preempt_list:
+                self._add_request_to_queue(req)

         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
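The preemption flow spread across the hunks above (reset `batch_is_full`, call `preempt_to_schedule` when the batch fills up, then push `adder.preempt_list` back onto the waiting queue) follows a common admit-with-preemption pattern. A rough sketch under assumed, simplified semantics; none of the names below are the real sglang APIs:

```python
def admit_with_preemption(waiting, running, capacity, try_preemption):
    """Admit waiting requests; optionally preempt running ones when capacity is hit."""
    admitted, preempted = [], []
    for req in list(waiting):
        if len(running) + len(admitted) >= capacity:
            if not try_preemption or not running:
                break
            preempted.append(running.pop())  # evict one running entry to make room
        admitted.append(req)
    # Preempted entries are re-queued, mirroring the adder.preempt_list handling above.
    waiting[:] = [r for r in waiting if r not in admitted] + preempted
    return admitted, preempted


waiting, running = ["d", "e"], ["a", "b", "c"]
print(admit_with_preemption(waiting, running, capacity=3, try_preemption=True))
# (['d', 'e'], ['c', 'b'])
```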
@@ -1678,7 +1938,16 @@ class Scheduler(

         # Print stats
         if self.current_scheduler_metrics_enabled():
-            self.log_prefill_stats(adder, can_run_list, running_bs)
+            self.log_prefill_stats(adder, can_run_list, running_bs, 0)
+
+        for req in can_run_list:
+            if req.time_stats.forward_entry_time == 0:
+                # Avoid updating a chunked request multiple times
+                req.time_stats.forward_entry_time = time.perf_counter()
+                if self.enable_metrics:
+                    self.metrics_collector.observe_queue_time(
+                        req.time_stats.get_queueing_time(),
+                    )

         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -1733,19 +2002,25 @@ class Scheduler(
             TEST_RETRACT and batch.batch_size() > 10
         ):
             old_ratio = self.new_token_ratio
-
-            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
-            num_retracted_reqs = len(retracted_reqs)
+            retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
+                self.server_args
+            )
+            self.num_retracted_reqs = len(retracted_reqs)
             self.new_token_ratio = new_token_ratio
+            for req in reqs_to_abort:
+                self.send_to_tokenizer.send_pyobj(
+                    AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
+                )

             logger.info(
                 "KV cache pool is full. Retract requests. "
-                f"#retracted_reqs: {num_retracted_reqs}, "
-                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
             )

-            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
-            self.total_retracted_reqs += num_retracted_reqs
+            for req in retracted_reqs:
+                self._add_request_to_queue(req, is_retracted=True)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -1773,37 +2048,25 @@ class Scheduler(

         # Run forward
         if self.is_generation:
+
+            batch_or_worker_batch = batch
+
             if self.spec_algorithm.is_none():
-                model_worker_batch = batch.get_model_worker_batch()
+                # FIXME(lsyin): remove this if and finally unify the abstraction
+                batch_or_worker_batch = batch.get_model_worker_batch()

-                # update the consumer index of hicache to the running batch
-                self.tp_worker.set_hicache_consumer(
-                    model_worker_batch.hicache_consumer_index
+            forward_batch_output = self.model_worker.forward_batch_generation(
+                batch_or_worker_batch
+            )
+
+            if not self.spec_algorithm.is_none():
+                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+                self.udpate_spec_metrics(
+                    batch.batch_size(), forward_batch_output.num_accepted_tokens
                 )
-                if self.pp_group.is_last_rank:
-                    logits_output, next_token_ids, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                bid = model_worker_batch.bid
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    bid,
-                    num_accepted_tokens,
-                    can_run_cuda_graph,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                bs = batch.batch_size()
-                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-                self.spec_num_total_forward_ct += bs
-                self.num_generated_tokens += num_accepted_tokens
-
-            if self.pp_group.is_last_rank:
-                batch.output_ids = next_token_ids
+
+            # update batch's output ids
+            batch.output_ids = forward_batch_output.next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -1812,6 +2075,7 @@ class Scheduler(
                 extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
             else:
                 extend_input_len_per_req = None
+
             if batch.return_logprob:
                 extend_logprob_start_len_per_req = [
                     req.extend_logprob_start_len for req in batch.reqs
@@ -1819,25 +2083,15 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None

-            ret = GenerationBatchResult(
-                logits_output=logits_output if self.pp_group.is_last_rank else None,
-                pp_hidden_states_proxy_tensors=(
-                    pp_hidden_states_proxy_tensors
-                    if not self.pp_group.is_last_rank
-                    else None
-                ),
-                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
+            return GenerationBatchResult.from_forward_batch_output(
+                forward_batch_output=forward_batch_output,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=bid,
-                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = EmbeddingBatchResult(
-                embeddings=embeddings, bid=model_worker_batch.bid
-            )
+            ret = EmbeddingBatchResult(embeddings=embeddings)
         return ret

     def process_batch_result(
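The rewrite above routes both the speculative and non-speculative paths through a single `forward_batch_generation` call and builds the batch result from its output object instead of unpacking path-specific tuples. A hedged sketch of that shape; the dataclass names below are illustrative, not the real `GenerationBatchResult` or forward-output classes:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeForwardOutput:
    next_token_ids: List[int]
    num_accepted_tokens: Optional[int] = None  # set only on the speculative path


@dataclass
class FakeBatchResult:
    next_token_ids: List[int]

    @classmethod
    def from_forward_output(cls, out: "FakeForwardOutput") -> "FakeBatchResult":
        return cls(next_token_ids=out.next_token_ids)


def run_step(out: FakeForwardOutput) -> FakeBatchResult:
    if out.num_accepted_tokens is not None:
        pass  # spec-decoding metrics would be updated here
    return FakeBatchResult.from_forward_output(out)


print(run_step(FakeForwardOutput([1, 2, 3])).next_token_ids)  # [1, 2, 3]
```

The benefit of this design is that the scheduler consumes one uniform output object, so the pipeline-parallel and speculative branches no longer need their own unpacking logic.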
@@ -1848,8 +2102,14 @@ class Scheduler(
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result, launch_done)
+            if self.enable_trace:
+                trace_slice_batch("decode loop", batch.reqs)
+
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result, launch_done)
+            if self.enable_trace:
+                trace_slice_batch("prefill", batch.reqs)
+
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2008,12 +2268,13 @@ class Scheduler(
                 if req.finished():  # It is aborted by AbortReq
                     num_ready_reqs += 1
                     continue
+
                 req.grammar = req.grammar.result(timeout=0.03)
                 self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 if req.grammar is INVALID_GRAMMAR_OBJ:
-                    req.set_finish_with_abort(
-                        f"Invalid grammar request: {req.grammar_key=}"
-                    )
+                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                    req.set_finish_with_abort(error_msg)
+
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
                 req.grammar_wait_ct += 1
@@ -2045,9 +2306,8 @@ class Scheduler(
                 req.grammar = req.grammar.result()
                 self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 if req.grammar is INVALID_GRAMMAR_OBJ:
-                    req.set_finish_with_abort(
-                        f"Invalid grammar request: {req.grammar_key=}"
-                    )
+                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                    req.set_finish_with_abort(error_msg)
         else:
             num_ready_reqs_max = num_ready_reqs
             num_timeout_reqs_max = num_timeout_reqs
@@ -2055,12 +2315,14 @@ class Scheduler(
         for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
             req = self.grammar_queue[i]
             req.grammar.cancel()
+            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
             error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
             req.set_finish_with_abort(error_msg)
-            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
+
         num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

-        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
+        for req in self.grammar_queue[:num_ready_reqs]:
+            self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

     def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
@@ -2152,9 +2414,8 @@ class Scheduler(
             self.req_to_token_pool.clear()
             self.token_to_kv_pool_allocator.clear()

-            if not self.spec_algorithm.is_none():
-                self.draft_worker.model_runner.req_to_token_pool.clear()
-                self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
+            if self.draft_worker:
+                self.draft_worker.clear_cache_pool()

             self.num_generated_tokens = 0
             self.forward_ct_decode = 0
@@ -2174,39 +2435,50 @@ class Scheduler(
             if_success = False
         return if_success

-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
+
         if self.is_hybrid:
-            load_full = (
+            num_tokens_full = (
                 self.full_tokens_per_layer
                 - self.token_to_kv_pool_allocator.full_available_size()
                 - self.tree_cache.full_evictable_size()
             )
-            load_swa = (
+            num_tokens_swa = (
                 self.swa_tokens_per_layer
                 - self.token_to_kv_pool_allocator.swa_available_size()
                 - self.tree_cache.swa_evictable_size()
             )
-            load = max(load_full, load_swa)
+            num_tokens = max(num_tokens_full, num_tokens_swa)
         else:
-            load = (
+            num_tokens = (
                 self.max_total_num_tokens
                 - self.token_to_kv_pool_allocator.available_size()
                 - self.tree_cache.evictable_size()
             )
-        load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+
+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            load += sum(
+            num_tokens += sum(
                 len(req.origin_input_ids)
                 for req in self.disagg_prefill_bootstrap_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            load += sum(
+            num_tokens += sum(
                 len(req.req.origin_input_ids)
                 for req in self.disagg_decode_prealloc_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

-        return load
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )

     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
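The new `get_load` above reports structured load instead of a single integer: tokens that are neither free nor evictable count as in use, tokens still sitting in the waiting/bootstrap/prealloc queues are added on top, and queue depths are returned alongside. A minimal sketch of that accounting with assumed field names (not the real `GetLoadReqOutput`):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class LoadReport:  # illustrative stand-in for GetLoadReqOutput
    num_reqs: int
    num_waiting_reqs: int
    num_tokens: int


def estimate_load(max_total_tokens: int, available: int, evictable: int,
                  running_reqs: int, queued_input_lens: List[int]) -> LoadReport:
    in_use = max_total_tokens - available - evictable
    return LoadReport(
        num_reqs=running_reqs + len(queued_input_lens),
        num_waiting_reqs=len(queued_input_lens),
        num_tokens=in_use + sum(queued_input_lens),
    )


print(estimate_load(10_000, available=6_000, evictable=1_000,
                    running_reqs=4, queued_input_lens=[128, 512]))
# LoadReport(num_reqs=6, num_waiting_reqs=2, num_tokens=3640)
```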
@@ -2221,10 +2493,9 @@ class Scheduler(
             "token_capacity": int(self.max_total_num_tokens),
         }

-        if not _is_cpu:
-            ret["memory_usage"]["cuda_graph"] = round(
-                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
-            )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
@@ -2233,8 +2504,6 @@ class Scheduler(
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict

-        ret["load"] = self.get_load()
-
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2310,7 +2579,7 @@ class Scheduler(
                 if self.enable_hicache_storage:
                     # to release prefetch events associated with the request
                     self.tree_cache.release_aborted_request(req.rid)
-                self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+                self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
                 # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
                 if self.disaggregation_mode == DisaggregationMode.DECODE:
                     self.tree_cache.cache_finished_req(req)
@@ -2331,31 +2600,31 @@ class Scheduler(
         # Delete requests not in the waiting queue when PD disaggregation is enabled
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # Abort requests that have not yet been bootstrapped
-            for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
-                logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+            for req in self.disagg_prefill_bootstrap_queue.queue:
                 if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                     if hasattr(req.disagg_kv_sender, "abort"):
                         req.disagg_kv_sender.abort()

             # Abort in-flight requests
-            for i, req in enumerate(self.disagg_prefill_inflight_queue):
-                logger.debug(f"Abort inflight queue request. {req.rid=}")
+            for req in self.disagg_prefill_inflight_queue:
                 if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                     if hasattr(req.disagg_kv_sender, "abort"):
                         req.disagg_kv_sender.abort()

         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # Abort requests that have not yet finished preallocation
-            for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
-                logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+            for decode_req in self.disagg_decode_prealloc_queue.queue:
                 if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                     if hasattr(decode_req.kv_receiver, "abort"):
                         decode_req.kv_receiver.abort()

             # Abort requests waiting for kvcache to release tree cache
-            for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
-                logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+            for decode_req in self.disagg_decode_transfer_queue.queue:
                 if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

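The abort loops above match requests either unconditionally (`abort_all`) or by rid prefix (`req.rid.startswith(recv_req.rid)`), so a single abort call can cover every sub-request that shares a base rid. A small illustration of that matching rule; the helper name is hypothetical:

```python
def should_abort(candidate_rid: str, abort_rid: str, abort_all: bool) -> bool:
    # Match everything when abort_all is set; otherwise match on the rid prefix.
    return abort_all or candidate_rid.startswith(abort_rid)


rids = ["req-42-0", "req-42-1", "req-43-0"]
print([r for r in rids if should_abort(r, "req-42", abort_all=False)])
# ['req-42-0', 'req-42-1']
print([r for r in rids if should_abort(r, "req-42", abort_all=True)])
# ['req-42-0', 'req-42-1', 'req-43-0']
```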
@@ -2398,6 +2667,22 @@ class Scheduler(
         self.send_to_detokenizer.send_pyobj(recv_req)
         return recv_req

+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        """Init the seed and client instance communication group."""
+        success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
+            recv_req
+        )
+        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        """Send the seed instance weights to the destination instance."""
+        success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
+        return SendWeightsToRemoteInstanceReqOutput(success, message)
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -2406,11 +2691,12 @@ class Scheduler(
         return SlowDownReqOutput()

     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
-        if recv_req == ExpertDistributionReq.START_RECORD:
+        action = recv_req.action
+        if action == ExpertDistributionReqType.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
-        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+        elif action == ExpertDistributionReqType.STOP_RECORD:
             get_global_expert_distribution_recorder().stop_record()
-        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+        elif action == ExpertDistributionReqType.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
@@ -2493,7 +2779,8 @@ class IdleSleeper:


 def is_health_check_generate_req(recv_req):
-    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+    rid = getattr(recv_req, "rid", None)
+    return rid is not None and rid.startswith("HEALTH_CHECK")


 def is_work_request(recv_req):
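The change to `is_health_check_generate_req` above guards against `rid` being present but set to `None`: `getattr(recv_req, "rid", "")` only falls back to `""` when the attribute is missing entirely, so a `None` rid would make `.startswith()` raise. A small repro with a stand-in request object:

```python
from types import SimpleNamespace


def is_health_check_old(recv_req):
    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")


def is_health_check_new(recv_req):
    rid = getattr(recv_req, "rid", None)
    return rid is not None and rid.startswith("HEALTH_CHECK")


req = SimpleNamespace(rid=None)
print(is_health_check_new(req))  # False
# is_health_check_old(req) would raise:
# AttributeError: 'NoneType' object has no attribute 'startswith'
print(is_health_check_new(SimpleNamespace(rid="HEALTH_CHECK_1")))  # True
```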
@@ -2517,10 +2804,12 @@ def run_scheduler_process(
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
-    balance_meta: Optional[DPBalanceMeta] = None,
 ):
-    # Generate the prefix
+    # Generate the logger prefix
     prefix = ""
+    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
+        # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
+        dp_rank = int(os.environ["SGLANG_DP_RANK"])
     if dp_rank is not None:
         prefix += f" DP{dp_rank}"
     if server_args.tp_size > 1:
@@ -2536,10 +2825,6 @@
     kill_itself_when_parent_died()
     parent_process = psutil.Process().parent()

-    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
-        dp_rank = int(os.environ["SGLANG_DP_RANK"])
-
     # Configure the logger
     configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()
@@ -2547,6 +2832,15 @@
     # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+    if (numa_node := server_args.numa_node) is not None:
+        numa_bind_to_node(numa_node[gpu_id])
+
+    # Set up tracing
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+            trace_set_thread_info(thread_label, tp_rank, dp_rank)

     # Create a scheduler and run the event loop
     try:
@@ -2558,7 +2852,6 @@
             moe_ep_rank,
             pp_rank,
             dp_rank,
-            dp_balance_meta=balance_meta,
         )
         pipe_writer.send(
             {