sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
44
44
  DecodeTransferQueue,
45
45
  SchedulerDisaggregationDecodeMixin,
46
46
  )
47
+ from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
48
+ DecodeKVCacheOffloadManager,
49
+ )
47
50
  from sglang.srt.disaggregation.prefill import (
48
51
  PrefillBootstrapQueue,
49
52
  SchedulerDisaggregationPrefillMixin,
@@ -57,11 +60,6 @@ from sglang.srt.disaggregation.utils import (
57
60
  )
58
61
  from sglang.srt.distributed import get_pp_group, get_world_group
59
62
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
60
- from sglang.srt.hf_transformers_utils import (
61
- get_processor,
62
- get_tokenizer,
63
- get_tokenizer_from_processor,
64
- )
65
63
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
66
64
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
67
65
  from sglang.srt.layers.moe import initialize_moe_config
@@ -72,20 +70,26 @@ from sglang.srt.managers.io_struct import (
72
70
  ClearHiCacheReqInput,
73
71
  ClearHiCacheReqOutput,
74
72
  CloseSessionReqInput,
73
+ DestroyWeightsUpdateGroupReqInput,
75
74
  ExpertDistributionReq,
76
75
  ExpertDistributionReqOutput,
76
+ ExpertDistributionReqType,
77
77
  FlushCacheReqInput,
78
78
  FlushCacheReqOutput,
79
79
  FreezeGCReq,
80
80
  GetInternalStateReq,
81
81
  GetInternalStateReqOutput,
82
+ GetLoadReqInput,
83
+ GetLoadReqOutput,
82
84
  GetWeightsByNameReqInput,
83
85
  HealthCheckOutput,
86
+ InitWeightsSendGroupForRemoteInstanceReqInput,
87
+ InitWeightsSendGroupForRemoteInstanceReqOutput,
84
88
  InitWeightsUpdateGroupReqInput,
85
89
  LoadLoRAAdapterReqInput,
86
90
  LoadLoRAAdapterReqOutput,
87
91
  MultiTokenizerRegisterReq,
88
- MultiTokenizerWarpper,
92
+ MultiTokenizerWrapper,
89
93
  OpenSessionReqInput,
90
94
  OpenSessionReqOutput,
91
95
  ProfileReq,
@@ -93,6 +97,8 @@ from sglang.srt.managers.io_struct import (
93
97
  ResumeMemoryOccupationReqInput,
94
98
  RpcReqInput,
95
99
  RpcReqOutput,
100
+ SendWeightsToRemoteInstanceReqInput,
101
+ SendWeightsToRemoteInstanceReqOutput,
96
102
  SetInternalStateReq,
97
103
  SetInternalStateReqOutput,
98
104
  SlowDownReqInput,
@@ -110,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
110
116
  FINISH_ABORT,
111
117
  MultimodalInputs,
112
118
  Req,
119
+ RequestStage,
113
120
  ScheduleBatch,
114
121
  global_server_args_dict,
115
122
  )
@@ -134,17 +141,28 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
134
141
  from sglang.srt.managers.session_controller import Session
135
142
  from sglang.srt.managers.tp_worker import TpModelWorker
136
143
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
137
- from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
144
+ from sglang.srt.managers.utils import validate_input_length
138
145
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
139
146
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
140
- from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
141
147
  from sglang.srt.mem_cache.radix_cache import RadixCache
142
148
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
143
- from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
144
- from sglang.srt.reasoning_parser import ReasoningParser
149
+ from sglang.srt.model_executor.forward_batch_info import (
150
+ ForwardBatchOutput,
151
+ ForwardMode,
152
+ PPProxyTensors,
153
+ )
154
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
145
155
  from sglang.srt.server_args import PortArgs, ServerArgs
146
156
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
147
157
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
158
+ from sglang.srt.tracing.trace import (
159
+ process_tracing_init,
160
+ trace_set_proc_propagate_context,
161
+ trace_set_thread_info,
162
+ trace_slice_batch,
163
+ trace_slice_end,
164
+ trace_slice_start,
165
+ )
148
166
  from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
149
167
  from sglang.srt.utils import (
150
168
  DynamicGradMode,
@@ -155,9 +173,10 @@ from sglang.srt.utils import (
155
173
  freeze_gc,
156
174
  get_available_gpu_memory,
157
175
  get_bool_env_var,
176
+ get_int_env_var,
158
177
  get_zmq_socket,
159
- is_cpu,
160
178
  kill_itself_when_parent_died,
179
+ numa_bind_to_node,
161
180
  point_to_point_pyobj,
162
181
  pyspy_dump_schedulers,
163
182
  require_mlp_sync,
@@ -166,6 +185,11 @@ from sglang.srt.utils import (
166
185
  set_random_seed,
167
186
  suppress_other_loggers,
168
187
  )
188
+ from sglang.srt.utils.hf_transformers_utils import (
189
+ get_processor,
190
+ get_tokenizer,
191
+ get_tokenizer_from_processor,
192
+ )
169
193
  from sglang.utils import TypeBasedDispatcher, get_exception_traceback
170
194
 
171
195
  logger = logging.getLogger(__name__)
@@ -174,24 +198,59 @@ logger = logging.getLogger(__name__)
174
198
  TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
175
199
  GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
176
200
 
177
- _is_cpu = is_cpu()
178
-
179
201
 
180
202
  @dataclass
181
203
  class GenerationBatchResult:
182
204
  logits_output: Optional[LogitsProcessorOutput]
183
- pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
205
+ pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
184
206
  next_token_ids: Optional[List[int]]
207
+ can_run_cuda_graph: bool
208
+
209
+ # For output processing
185
210
  extend_input_len_per_req: List[int]
186
211
  extend_logprob_start_len_per_req: List[int]
187
- bid: int
188
- can_run_cuda_graph: bool
212
+
213
+ @classmethod
214
+ def from_forward_batch_output(
215
+ cls,
216
+ forward_batch_output: ForwardBatchOutput,
217
+ extend_input_len_per_req: List[int],
218
+ extend_logprob_start_len_per_req: List[int],
219
+ ):
220
+ # TODO(lsyin): remove this workaround logic and try to unify output classes
221
+
222
+ return cls(
223
+ logits_output=forward_batch_output.logits_output,
224
+ pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
225
+ next_token_ids=forward_batch_output.next_token_ids,
226
+ extend_input_len_per_req=extend_input_len_per_req,
227
+ extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
228
+ can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
229
+ )
230
+
231
+ @classmethod
232
+ def from_pp_proxy(
233
+ cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
234
+ ):
235
+ # TODO(lsyin): also simplify this logic
236
+ # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
237
+ # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
238
+ proxy_dict = next_pp_outputs.tensors
239
+ return cls(
240
+ logits_output=logits_output,
241
+ pp_hidden_states_proxy_tensors=None,
242
+ next_token_ids=next_pp_outputs["next_token_ids"],
243
+ extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
244
+ extend_logprob_start_len_per_req=proxy_dict.get(
245
+ "extend_logprob_start_len_per_req", None
246
+ ),
247
+ can_run_cuda_graph=can_run_cuda_graph,
248
+ )
189
249
 
190
250
 
191
251
  @dataclass
192
252
  class EmbeddingBatchResult:
193
253
  embeddings: torch.Tensor
194
- bid: int
195
254
 
196
255
 
197
256
  class Scheduler(
@@ -213,7 +272,6 @@ class Scheduler(
213
272
  moe_ep_rank: int,
214
273
  pp_rank: int,
215
274
  dp_rank: Optional[int],
216
- dp_balance_meta: Optional[DPBalanceMeta] = None,
217
275
  ):
218
276
  # Parse args
219
277
  self.server_args = server_args
@@ -226,6 +284,13 @@ class Scheduler(
226
284
  self.pp_size = server_args.pp_size
227
285
  self.dp_size = server_args.dp_size
228
286
  self.schedule_policy = server_args.schedule_policy
287
+ self.enable_priority_scheduling = server_args.enable_priority_scheduling
288
+ self.schedule_low_priority_values_first = (
289
+ server_args.schedule_low_priority_values_first
290
+ )
291
+ self.priority_scheduling_preemption_threshold = (
292
+ server_args.priority_scheduling_preemption_threshold
293
+ )
229
294
  self.enable_lora = server_args.enable_lora
230
295
  self.max_loras_per_batch = server_args.max_loras_per_batch
231
296
  self.enable_overlap = not server_args.disable_overlap_schedule
@@ -234,7 +299,10 @@ class Scheduler(
234
299
  self.enable_metrics_for_all_schedulers = (
235
300
  server_args.enable_metrics_for_all_schedulers
236
301
  )
237
- self.enable_kv_cache_events = server_args.kv_events_config is not None
302
+ self.enable_kv_cache_events = bool(
303
+ server_args.kv_events_config and tp_rank == 0
304
+ )
305
+ self.enable_trace = server_args.enable_trace
238
306
  self.stream_interval = server_args.stream_interval
239
307
  self.spec_algorithm = SpeculativeAlgorithm.from_string(
240
308
  server_args.speculative_algorithm
@@ -348,9 +416,39 @@ class Scheduler(
348
416
  target_worker=self.tp_worker,
349
417
  dp_rank=dp_rank,
350
418
  )
419
+ elif self.spec_algorithm.is_standalone():
420
+ from sglang.srt.speculative.standalone_worker import StandaloneWorker
421
+
422
+ self.draft_worker = StandaloneWorker(
423
+ gpu_id=gpu_id,
424
+ tp_rank=tp_rank,
425
+ moe_ep_rank=moe_ep_rank,
426
+ server_args=server_args,
427
+ nccl_port=port_args.nccl_port,
428
+ target_worker=self.tp_worker,
429
+ dp_rank=dp_rank,
430
+ )
431
+ elif self.spec_algorithm.is_ngram():
432
+ from sglang.srt.speculative.ngram_worker import NGRAMWorker
433
+
434
+ self.draft_worker = NGRAMWorker(
435
+ gpu_id=gpu_id,
436
+ tp_rank=tp_rank,
437
+ moe_ep_rank=moe_ep_rank,
438
+ server_args=server_args,
439
+ nccl_port=port_args.nccl_port,
440
+ target_worker=self.tp_worker,
441
+ dp_rank=dp_rank,
442
+ )
351
443
  else:
352
444
  self.draft_worker = None
353
445
 
446
+ # Dispatch the model worker
447
+ if self.spec_algorithm.is_none():
448
+ self.model_worker = self.tp_worker
449
+ else:
450
+ self.model_worker = self.draft_worker
451
+
354
452
  # Get token and memory info from the model worker
355
453
  (
356
454
  self.max_total_num_tokens,
@@ -401,7 +499,7 @@ class Scheduler(
401
499
  f"max_prefill_tokens={self.max_prefill_tokens}, "
402
500
  f"max_running_requests={self.max_running_requests}, "
403
501
  f"context_len={self.model_config.context_len}, "
404
- f"available_gpu_mem={avail_mem:.2f} GB"
502
+ f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
405
503
  )
406
504
 
407
505
  # Init memory pool and cache
@@ -458,7 +556,12 @@ class Scheduler(
458
556
  self.schedule_policy,
459
557
  self.tree_cache,
460
558
  self.enable_hierarchical_cache,
559
+ self.enable_priority_scheduling,
560
+ self.schedule_low_priority_values_first,
461
561
  )
562
+ # Enable preemption for priority scheduling.
563
+ self.try_preemption = self.enable_priority_scheduling
564
+
462
565
  assert (
463
566
  server_args.schedule_conservativeness >= 0
464
567
  ), "Invalid schedule_conservativeness"
@@ -488,7 +591,7 @@ class Scheduler(
488
591
  enable=server_args.enable_memory_saver
489
592
  )
490
593
  self.offload_tags = set()
491
- self.init_profier()
594
+ self.init_profiler()
492
595
 
493
596
  self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
494
597
  self.input_blocker = (
@@ -499,7 +602,9 @@ class Scheduler(
499
602
 
500
603
  # Init metrics stats
501
604
  self.init_metrics(tp_rank, pp_rank, dp_rank)
502
- self.init_kv_events(server_args.kv_events_config)
605
+
606
+ if self.enable_kv_cache_events:
607
+ self.init_kv_events(server_args.kv_events_config)
503
608
 
504
609
  # Init disaggregation
505
610
  self.disaggregation_mode = DisaggregationMode(
@@ -510,6 +615,9 @@ class Scheduler(
510
615
  if get_bool_env_var("SGLANG_GC_LOG"):
511
616
  configure_gc_logger()
512
617
 
618
+ # Init prefill kv split size when deterministic inference is enabled with various attention backends
619
+ self.init_deterministic_inference_config()
620
+
513
621
  # Init request dispatcher
514
622
  self._request_dispatcher = TypeBasedDispatcher(
515
623
  [
@@ -524,6 +632,15 @@ class Scheduler(
524
632
  (CloseSessionReqInput, self.close_session),
525
633
  (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
526
634
  (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
635
+ (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
636
+ (
637
+ InitWeightsSendGroupForRemoteInstanceReqInput,
638
+ self.init_weights_send_group_for_remote_instance,
639
+ ),
640
+ (
641
+ SendWeightsToRemoteInstanceReqInput,
642
+ self.send_weights_to_remote_instance,
643
+ ),
527
644
  (
528
645
  UpdateWeightsFromDistributedReqInput,
529
646
  self.update_weights_from_distributed,
@@ -542,17 +659,26 @@ class Scheduler(
542
659
  (LoadLoRAAdapterReqInput, self.load_lora_adapter),
543
660
  (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
544
661
  (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
662
+ (GetLoadReqInput, self.get_load),
545
663
  ]
546
664
  )
547
665
 
548
- self.balance_meta = dp_balance_meta
549
- if (
550
- server_args.enable_dp_attention
551
- and server_args.load_balance_method == "minimum_tokens"
552
- ):
553
- assert dp_balance_meta is not None
666
+ def init_deterministic_inference_config(self):
667
+ """Initialize deterministic inference configuration for different attention backends."""
668
+ if not self.server_args.enable_deterministic_inference:
669
+ self.truncation_align_size = None
670
+ return
554
671
 
555
- self.recv_dp_balance_id_this_term = []
672
+ backend_sizes = {
673
+ "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
674
+ "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
675
+ }
676
+ env_var, default_size = backend_sizes.get(
677
+ self.server_args.attention_backend, (None, None)
678
+ )
679
+ self.truncation_align_size = (
680
+ get_int_env_var(env_var, default_size) if env_var else None
681
+ )
556
682
 
557
683
  def init_tokenizer(self):
558
684
  server_args = self.server_args
@@ -625,15 +751,18 @@ class Scheduler(
625
751
  else self.tp_cpu_group
626
752
  ),
627
753
  page_size=self.page_size,
754
+ eviction_policy=server_args.radix_eviction_policy,
628
755
  hicache_ratio=server_args.hicache_ratio,
629
756
  hicache_size=server_args.hicache_size,
630
757
  hicache_write_policy=server_args.hicache_write_policy,
631
758
  hicache_io_backend=server_args.hicache_io_backend,
632
759
  hicache_mem_layout=server_args.hicache_mem_layout,
760
+ enable_metrics=self.enable_metrics,
633
761
  hicache_storage_backend=server_args.hicache_storage_backend,
634
762
  hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
635
763
  model_name=server_args.served_model_name,
636
764
  storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
765
+ is_eagle=self.spec_algorithm.is_eagle(),
637
766
  )
638
767
  self.tp_worker.register_hicache_layer_transfer_counter(
639
768
  self.tree_cache.cache_controller.layer_done_counter
@@ -649,18 +778,21 @@ class Scheduler(
649
778
  page_size=self.page_size,
650
779
  disable=server_args.disable_radix_cache,
651
780
  )
652
- elif self.enable_lora:
653
- assert (
654
- not self.enable_hierarchical_cache
655
- ), "LoRA radix cache doesn't support hierarchical cache"
656
- assert (
657
- self.schedule_policy == "fcfs"
658
- ), "LoRA radix cache only supports FCFS policy"
659
- self.tree_cache = LoRARadixCache(
781
+ elif server_args.enable_lmcache:
782
+ from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
783
+ LMCRadixCache,
784
+ )
785
+
786
+ self.tree_cache = LMCRadixCache(
660
787
  req_to_token_pool=self.req_to_token_pool,
661
788
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
662
789
  page_size=self.page_size,
663
790
  disable=server_args.disable_radix_cache,
791
+ model_config=self.model_config,
792
+ tp_size=self.tp_size,
793
+ rank=self.tp_rank,
794
+ tp_group=self.tp_group,
795
+ eviction_policy=server_args.radix_eviction_policy,
664
796
  )
665
797
  else:
666
798
  self.tree_cache = RadixCache(
@@ -669,16 +801,36 @@ class Scheduler(
669
801
  page_size=self.page_size,
670
802
  disable=server_args.disable_radix_cache,
671
803
  enable_kv_cache_events=self.enable_kv_cache_events,
804
+ eviction_policy=server_args.radix_eviction_policy,
805
+ is_eagle=self.spec_algorithm.is_eagle(),
672
806
  )
673
807
 
808
+ if (
809
+ server_args.disaggregation_mode == "decode"
810
+ and server_args.disaggregation_decode_enable_offload_kvcache
811
+ ):
812
+ self.decode_offload_manager = DecodeKVCacheOffloadManager(
813
+ req_to_token_pool=self.req_to_token_pool,
814
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
815
+ tp_group=(
816
+ self.attn_tp_cpu_group
817
+ if self.server_args.enable_dp_attention
818
+ else self.tp_cpu_group
819
+ ),
820
+ tree_cache=self.tree_cache,
821
+ server_args=self.server_args,
822
+ )
823
+ else:
824
+ self.decode_offload_manager = None
825
+
674
826
  self.decode_mem_cache_buf_multiplier = (
675
827
  1
676
828
  if self.spec_algorithm.is_none()
677
829
  else (
678
830
  server_args.speculative_num_draft_tokens
679
831
  + (
680
- server_args.speculative_eagle_topk
681
- * server_args.speculative_num_steps
832
+ (server_args.speculative_eagle_topk or 1)
833
+ * (server_args.speculative_num_steps or 1)
682
834
  )
683
835
  )
684
836
  )
@@ -701,7 +853,7 @@ class Scheduler(
701
853
  self.disagg_metadata_buffers = MetadataBuffers(
702
854
  buffer_size,
703
855
  hidden_size=self.model_config.hf_text_config.hidden_size,
704
- dtype=self.model_config.dtype,
856
+ hidden_states_dtype=self.model_config.dtype,
705
857
  custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
706
858
  )
707
859
 
@@ -721,7 +873,7 @@ class Scheduler(
721
873
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
722
874
  draft_token_to_kv_pool=(
723
875
  None
724
- if self.draft_worker is None
876
+ if self.draft_worker is None or self.spec_algorithm.is_ngram()
725
877
  else self.draft_worker.model_runner.token_to_kv_pool
726
878
  ),
727
879
  req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -750,7 +902,7 @@ class Scheduler(
750
902
  self.disagg_metadata_buffers = MetadataBuffers(
751
903
  buffer_size,
752
904
  hidden_size=self.model_config.hf_text_config.hidden_size,
753
- dtype=self.model_config.dtype,
905
+ hidden_states_dtype=self.model_config.dtype,
754
906
  custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
755
907
  )
756
908
 
@@ -758,7 +910,7 @@ class Scheduler(
758
910
  token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
759
911
  draft_token_to_kv_pool=(
760
912
  None
761
- if self.draft_worker is None
913
+ if self.draft_worker is None or self.spec_algorithm.is_ngram()
762
914
  else self.draft_worker.model_runner.token_to_kv_pool
763
915
  ),
764
916
  req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -853,7 +1005,6 @@ class Scheduler(
853
1005
  self.running_mbs = [
854
1006
  ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
855
1007
  ]
856
- bids = [None] * self.pp_size
857
1008
  pp_outputs: Optional[PPProxyTensors] = None
858
1009
  while True:
859
1010
  server_is_idle = True
@@ -874,10 +1025,7 @@ class Scheduler(
874
1025
  # (last rank) send the outputs to the next step
875
1026
  if self.pp_group.is_last_rank:
876
1027
  if self.cur_batch:
877
- next_token_ids, bids[mb_id] = (
878
- result.next_token_ids,
879
- result.bid,
880
- )
1028
+ next_token_ids = result.next_token_ids
881
1029
  if self.cur_batch.return_logprob:
882
1030
  pp_outputs = PPProxyTensors(
883
1031
  {
@@ -925,17 +1073,10 @@ class Scheduler(
925
1073
  logits_output = LogitsProcessorOutput(**logits_output_args)
926
1074
  else:
927
1075
  logits_output = None
928
- output_result = GenerationBatchResult(
1076
+
1077
+ output_result = GenerationBatchResult.from_pp_proxy(
929
1078
  logits_output=logits_output,
930
- pp_hidden_states_proxy_tensors=None,
931
- next_token_ids=next_pp_outputs["next_token_ids"],
932
- extend_input_len_per_req=next_pp_outputs.tensors.get(
933
- "extend_input_len_per_req", None
934
- ),
935
- extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
936
- "extend_logprob_start_len_per_req", None
937
- ),
938
- bid=bids[next_mb_id],
1079
+ next_pp_outputs=next_pp_outputs,
939
1080
  can_run_cuda_graph=result.can_run_cuda_graph,
940
1081
  )
941
1082
  self.process_batch_result(mbs[next_mb_id], output_result)
@@ -943,8 +1084,6 @@ class Scheduler(
943
1084
 
944
1085
  # (not last rank)
945
1086
  if not self.pp_group.is_last_rank:
946
- if self.cur_batch:
947
- bids[mb_id] = result.bid
948
1087
  # carry the outputs to the next stage
949
1088
  # send the outputs from the last round to let the next stage worker run post processing
950
1089
  if pp_outputs:
@@ -966,8 +1105,10 @@ class Scheduler(
966
1105
 
967
1106
  # send out proxy tensors to the next stage
968
1107
  if self.cur_batch:
1108
+ # FIXME(lsyin): remove this assert
1109
+ assert result.pp_hidden_states_proxy_tensors.tensors is not None
969
1110
  self.pp_group.send_tensor_dict(
970
- result.pp_hidden_states_proxy_tensors,
1111
+ result.pp_hidden_states_proxy_tensors.tensors,
971
1112
  all_gather_group=self.attn_tp_group,
972
1113
  )
973
1114
 
@@ -1077,6 +1218,15 @@ class Scheduler(
1077
1218
  self.tp_cpu_group,
1078
1219
  src=self.tp_group.ranks[0],
1079
1220
  )
1221
+
1222
+ if self.enable_trace:
1223
+ for req in recv_reqs:
1224
+ if isinstance(
1225
+ req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
1226
+ ):
1227
+ trace_set_proc_propagate_context(req.rid, req.trace_context)
1228
+ trace_slice_start("", req.rid, anonymous=True)
1229
+
1080
1230
  return recv_reqs
1081
1231
 
1082
1232
  def process_input_requests(self, recv_reqs: List):
@@ -1090,27 +1240,13 @@ class Scheduler(
1090
1240
  self.return_health_check_ct += 1
1091
1241
  continue
1092
1242
 
1093
- # If it is a work request, accept or reject the request based on the request queue size.
1094
- if is_work_request(recv_req):
1095
- if len(self.waiting_queue) + 1 > self.max_queued_requests:
1096
- abort_req = AbortReq(
1097
- recv_req.rid,
1098
- finished_reason={
1099
- "type": "abort",
1100
- "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1101
- "message": "The request queue is full.",
1102
- },
1103
- )
1104
- self.send_to_tokenizer.send_pyobj(abort_req)
1105
- continue
1106
-
1107
- # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
1108
- if isinstance(recv_req, MultiTokenizerWarpper):
1243
+ # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
1244
+ if isinstance(recv_req, MultiTokenizerWrapper):
1109
1245
  worker_id = recv_req.worker_id
1110
1246
  recv_req = recv_req.obj
1111
1247
  output = self._request_dispatcher(recv_req)
1112
1248
  if output is not None:
1113
- output = MultiTokenizerWarpper(worker_id, output)
1249
+ output = MultiTokenizerWrapper(worker_id, output)
1114
1250
  self.send_to_tokenizer.send_pyobj(output)
1115
1251
  continue
1116
1252
 
@@ -1122,16 +1258,20 @@ class Scheduler(
1122
1258
  else:
1123
1259
  self.send_to_tokenizer.send_pyobj(output)
1124
1260
 
1261
+ def init_req_max_new_tokens(self, req):
1262
+ req.sampling_params.max_new_tokens = min(
1263
+ (
1264
+ req.sampling_params.max_new_tokens
1265
+ if req.sampling_params.max_new_tokens is not None
1266
+ else 1 << 30
1267
+ ),
1268
+ self.max_req_len - len(req.origin_input_ids) - 1,
1269
+ )
1270
+
1125
1271
  def handle_generate_request(
1126
1272
  self,
1127
1273
  recv_req: TokenizedGenerateReqInput,
1128
1274
  ):
1129
- if (
1130
- self.server_args.enable_dp_attention
1131
- and self.server_args.load_balance_method == "minimum_tokens"
1132
- ):
1133
- self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
1134
-
1135
1275
  # Create a new request
1136
1276
  if (
1137
1277
  recv_req.session_params is None
@@ -1165,8 +1305,13 @@ class Scheduler(
1165
1305
  bootstrap_host=recv_req.bootstrap_host,
1166
1306
  bootstrap_port=recv_req.bootstrap_port,
1167
1307
  bootstrap_room=recv_req.bootstrap_room,
1308
+ disagg_mode=self.disaggregation_mode,
1168
1309
  data_parallel_rank=recv_req.data_parallel_rank,
1169
1310
  vocab_size=self.model_config.vocab_size,
1311
+ priority=recv_req.priority,
1312
+ metrics_collector=(
1313
+ self.metrics_collector if self.enable_metrics else None
1314
+ ),
1170
1315
  )
1171
1316
  req.tokenizer = self.tokenizer
1172
1317
 
@@ -1189,6 +1334,7 @@ class Scheduler(
1189
1334
  req.set_finish_with_abort(
1190
1335
  f"Invalid request: session id {recv_req.session_params.id} does not exist"
1191
1336
  )
1337
+ self.init_req_max_new_tokens(req)
1192
1338
  self._add_request_to_queue(req)
1193
1339
  return
1194
1340
  else:
@@ -1196,6 +1342,7 @@ class Scheduler(
1196
1342
  session = self.sessions[recv_req.session_params.id]
1197
1343
  req = session.create_req(recv_req, self.tokenizer)
1198
1344
  if isinstance(req.finished_reason, FINISH_ABORT):
1345
+ self.init_req_max_new_tokens(req)
1199
1346
  self._add_request_to_queue(req)
1200
1347
  return
1201
1348
 
@@ -1215,9 +1362,13 @@ class Scheduler(
1215
1362
  f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
1216
1363
  )
1217
1364
  )
1365
+ self.init_req_max_new_tokens(req)
1218
1366
  self._add_request_to_queue(req)
1219
1367
  return
1220
1368
 
1369
+ # initialize before returning
1370
+ self.init_req_max_new_tokens(req)
1371
+
1221
1372
  # Validate prompt length
1222
1373
  error_msg = validate_input_length(
1223
1374
  req,
@@ -1232,26 +1383,25 @@ class Scheduler(
1232
1383
  # Copy more attributes
1233
1384
  if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
1234
1385
  # By default, only return the logprobs for output tokens
1235
- req.logprob_start_len = len(req.origin_input_ids) - 1
1386
+ # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
1387
+ # to skip input logprob computation entirely
1388
+ if req.is_prefill_only:
1389
+ req.logprob_start_len = len(req.origin_input_ids)
1390
+ else:
1391
+ # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
1392
+ req.logprob_start_len = len(req.origin_input_ids) - 1
1236
1393
  else:
1237
1394
  req.logprob_start_len = recv_req.logprob_start_len
1238
1395
 
1239
- if req.logprob_start_len >= len(req.origin_input_ids):
1396
+ if not req.is_prefill_only and req.logprob_start_len >= len(
1397
+ req.origin_input_ids
1398
+ ):
1240
1399
  error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
1241
1400
  req.logprob_start_len = len(req.origin_input_ids) - 1
1242
1401
  req.set_finish_with_abort(error_msg)
1243
1402
  self._add_request_to_queue(req)
1244
1403
  return
1245
1404
 
1246
- req.sampling_params.max_new_tokens = min(
1247
- (
1248
- req.sampling_params.max_new_tokens
1249
- if req.sampling_params.max_new_tokens is not None
1250
- else 1 << 30
1251
- ),
1252
- self.max_req_len - len(req.origin_input_ids) - 1,
1253
- )
1254
-
1255
1405
  # Init grammar cache for this request
1256
1406
  add_to_grammar_queue = False
1257
1407
  if (
@@ -1282,7 +1432,6 @@ class Scheduler(
1282
1432
  req.set_finish_with_abort(error_msg)
1283
1433
 
1284
1434
  if add_to_grammar_queue:
1285
- req.queue_time_start = time.perf_counter()
1286
1435
  self.grammar_queue.append(req)
1287
1436
  else:
1288
1437
  self._add_request_to_queue(req)
@@ -1298,19 +1447,6 @@ class Scheduler(
1298
1447
  for tokenized_req in recv_req:
1299
1448
  self.handle_generate_request(tokenized_req)
1300
1449
 
1301
- def _add_request_to_queue(self, req: Req):
1302
- req.queue_time_start = time.perf_counter()
1303
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
1304
- self._prefetch_kvcache(req)
1305
- self.disagg_prefill_bootstrap_queue.add(
1306
- req, self.model_config.num_key_value_heads
1307
- )
1308
- elif self.disaggregation_mode == DisaggregationMode.DECODE:
1309
- self.disagg_decode_prealloc_queue.add(req)
1310
- else:
1311
- self._prefetch_kvcache(req)
1312
- self.waiting_queue.append(req)
1313
-
1314
1450
  def _prefetch_kvcache(self, req: Req):
1315
1451
  if self.enable_hicache_storage:
1316
1452
  req.init_next_round_input(self.tree_cache)
@@ -1324,16 +1460,87 @@ class Scheduler(
1324
1460
  req.rid, req.last_host_node, new_input_tokens, last_hash
1325
1461
  )
1326
1462
 
1327
- def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
1328
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
1329
- self.disagg_prefill_bootstrap_queue.extend(
1330
- reqs, self.model_config.num_key_value_heads
1463
+ def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
1464
+ if self.disaggregation_mode == DisaggregationMode.NULL:
1465
+ self._set_or_validate_priority(req)
1466
+ if self._abort_on_queued_limit(req):
1467
+ return
1468
+ self._prefetch_kvcache(req)
1469
+ self.waiting_queue.append(req)
1470
+ req.time_stats.wait_queue_entry_time = time.perf_counter()
1471
+ trace_slice_end("process req", req.rid, auto_next_anon=True)
1472
+ elif self.disaggregation_mode == DisaggregationMode.PREFILL:
1473
+ self._prefetch_kvcache(req)
1474
+ self.disagg_prefill_bootstrap_queue.add(
1475
+ req, self.model_config.num_key_value_heads
1331
1476
  )
1477
+ req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
1332
1478
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
1333
- # If this is a decode server, we put the request to the decode pending prealloc queue
1334
- self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
1479
+ self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
1480
+ if not is_retracted:
1481
+ req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
1335
1482
  else:
1336
- self.waiting_queue.extend(reqs)
1483
+ raise ValueError(f"Invalid {self.disaggregation_mode=}")
1484
+
1485
+ def _set_or_validate_priority(self, req: Req):
1486
+ """Set the default priority value, or abort the request based on the priority scheduling mode."""
1487
+ if self.enable_priority_scheduling and req.priority is None:
1488
+ if self.schedule_low_priority_values_first:
1489
+ req.priority = sys.maxsize
1490
+ else:
1491
+ req.priority = -sys.maxsize - 1
1492
+ elif not self.enable_priority_scheduling and req.priority is not None:
1493
+ abort_req = AbortReq(
1494
+ finished_reason={
1495
+ "type": "abort",
1496
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1497
+ "message": "Using priority is disabled for this server. Please send a new request without a priority.",
1498
+ },
1499
+ rid=req.rid,
1500
+ )
1501
+ self.send_to_tokenizer.send_pyobj(abort_req)
1502
+
1503
+ def _abort_on_queued_limit(self, recv_req: Req) -> bool:
1504
+ """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
1505
+ if (
1506
+ self.max_queued_requests is None
1507
+ or len(self.waiting_queue) + 1 <= self.max_queued_requests
1508
+ ):
1509
+ return False
1510
+
1511
+ # Reject the incoming request by default.
1512
+ req_to_abort = recv_req
1513
+ message = "The request queue is full."
1514
+ if self.enable_priority_scheduling:
1515
+ # With priority scheduling, consider aboritng an existing request based on the priority.
1516
+ # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
1517
+ # max(...) + (direction * priority, queue_time_start) picks the least-preferred request.
1518
+ # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
1519
+ direction = 1 if self.schedule_low_priority_values_first else -1
1520
+ key_fn = lambda item: (
1521
+ direction * item[1].priority,
1522
+ item[1].time_stats.wait_queue_entry_time,
1523
+ )
1524
+ idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
1525
+ abort_existing_req = (
1526
+ direction * recv_req.priority < direction * candidate_req.priority
1527
+ )
1528
+ if abort_existing_req:
1529
+ self.waiting_queue.pop(idx)
1530
+ req_to_abort = candidate_req
1531
+ message = "The request is aborted by a higher priority request."
1532
+
1533
+ self.send_to_tokenizer.send_pyobj(
1534
+ AbortReq(
1535
+ finished_reason={
1536
+ "type": "abort",
1537
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1538
+ "message": message,
1539
+ },
1540
+ rid=req_to_abort.rid,
1541
+ )
1542
+ )
1543
+ return req_to_abort.rid == recv_req.rid
1337
1544
 
1338
1545
  def handle_embedding_request(
1339
1546
  self,
@@ -1345,6 +1552,7 @@ class Scheduler(
1345
1552
  recv_req.input_ids,
1346
1553
  recv_req.sampling_params,
1347
1554
  token_type_ids=recv_req.token_type_ids,
1555
+ priority=recv_req.priority,
1348
1556
  )
1349
1557
  req.tokenizer = self.tokenizer
1350
1558
 
@@ -1421,9 +1629,11 @@ class Scheduler(
1421
1629
  _, _, available_size, evictable_size = self._get_token_info()
1422
1630
  protected_size = self.tree_cache.protected_size()
1423
1631
  memory_leak = (available_size + evictable_size) != (
1632
+ # self.max_total_num_tokens
1633
+ # if not self.enable_hierarchical_cache
1634
+ # else self.max_total_num_tokens - protected_size
1424
1635
  self.max_total_num_tokens
1425
- if not self.enable_hierarchical_cache
1426
- else self.max_total_num_tokens - protected_size
1636
+ - protected_size
1427
1637
  )
1428
1638
  token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
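Note: with the hierarchical-cache branch commented out, the leak check above reduces to one invariant: free slots plus evictable slots must equal the pool size minus whatever the tree cache protects. A toy restatement with invented numbers:

    max_total_num_tokens = 1000        # hypothetical pool capacity
    available_size = 700
    evictable_size = 250
    protected_size = 50

    memory_leak = (available_size + evictable_size) != (max_total_num_tokens - protected_size)
    print(memory_leak)                 # False: 700 + 250 == 1000 - 50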
1429
1639
 
@@ -1474,6 +1684,20 @@ class Scheduler(
1474
1684
  self.stats.gen_throughput = 0
1475
1685
  self.stats.num_queue_reqs = len(self.waiting_queue)
1476
1686
  self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
1687
+ if self.disaggregation_mode == DisaggregationMode.PREFILL:
1688
+ self.stats.num_prefill_prealloc_queue_reqs = len(
1689
+ self.disagg_prefill_bootstrap_queue.queue
1690
+ )
1691
+ self.stats.num_prefill_inflight_queue_reqs = len(
1692
+ self.disagg_prefill_inflight_queue
1693
+ )
1694
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
1695
+ self.stats.num_decode_prealloc_queue_reqs = len(
1696
+ self.disagg_decode_prealloc_queue.queue
1697
+ )
1698
+ self.stats.num_decode_transfer_queue_reqs = len(
1699
+ self.disagg_decode_transfer_queue.queue
1700
+ )
1477
1701
  self.metrics_collector.log_stats(self.stats)
1478
1702
  self._publish_kv_events()
1479
1703
 
@@ -1521,7 +1745,12 @@ class Scheduler(
1521
1745
  chunked_req_to_exclude.add(self.chunked_req)
1522
1746
  self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
1523
1747
  # chunked request keeps its rid but will get a new req_pool_idx
1524
- self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
1748
+ if self.tp_worker.worker.model_runner.is_hybrid_gdn:
1749
+ self.req_to_token_pool.free(
1750
+ self.chunked_req.req_pool_idx, free_mamba_cache=False
1751
+ )
1752
+ else:
1753
+ self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
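Note: a small sketch of the branch above, assuming a pool whose free() accepts the free_mamba_cache flag shown in the hunk; for hybrid GDN models the slot is released while the mamba state is presumably kept for the next chunk:

    class ReqToTokenPool:                            # hypothetical stand-in
        def free(self, idx, free_mamba_cache=True):
            print(f"free slot {idx}, free_mamba_cache={free_mamba_cache}")

    def free_chunked_slot(pool, idx, is_hybrid_gdn):
        if is_hybrid_gdn:
            pool.free(idx, free_mamba_cache=False)   # keep state across chunks
        else:
            pool.free(idx)

    free_chunked_slot(ReqToTokenPool(), idx=3, is_hybrid_gdn=True)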
1525
1754
  if self.last_batch and self.last_batch.forward_mode.is_extend():
1526
1755
  if self.last_batch.chunked_req is not None:
1527
1756
  # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks the outdated chunked_req.
@@ -1568,11 +1797,6 @@ class Scheduler(
1568
1797
 
1569
1798
  # Handle DP attention
1570
1799
  if need_dp_attn_preparation:
1571
- if (
1572
- self.server_args.load_balance_method == "minimum_tokens"
1573
- and self.forward_ct % 40 == 0
1574
- ):
1575
- self.handle_dp_balance_data(ret)
1576
1800
  ret = self.prepare_mlp_sync_batch(ret)
1577
1801
 
1578
1802
  return ret
@@ -1588,6 +1812,10 @@ class Scheduler(
1588
1812
  if self.grammar_queue:
1589
1813
  self.move_ready_grammar_requests()
1590
1814
 
1815
+ if self.try_preemption:
1816
+ # Reset batch_is_full to try preemption with a prefill adder.
1817
+ self.running_batch.batch_is_full = False
1818
+
1591
1819
  # Handle the cases where prefill is not allowed
1592
1820
  if (
1593
1821
  self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1600,7 +1828,11 @@ class Scheduler(
1600
1828
  # as the space for the chunked request has just been released.
1601
1829
  # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
1602
1830
  # Instead, we should always allow the chunked request to be added; otherwise, there will be a memory leak.
1603
- if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
1831
+ if (
1832
+ self.get_num_allocatable_reqs(running_bs) <= 0
1833
+ and not self.chunked_req
1834
+ and not self.try_preemption
1835
+ ):
1604
1836
  self.running_batch.batch_is_full = True
1605
1837
  return None
1606
1838
 
@@ -1620,6 +1852,7 @@ class Scheduler(
1620
1852
  self.max_prefill_tokens,
1621
1853
  self.chunked_prefill_size,
1622
1854
  running_bs if self.is_mixed_chunk else 0,
1855
+ self.priority_scheduling_preemption_threshold,
1623
1856
  )
1624
1857
 
1625
1858
  if self.chunked_req is not None:
@@ -1640,15 +1873,19 @@ class Scheduler(
1640
1873
  self.running_batch.batch_is_full = True
1641
1874
  break
1642
1875
 
1876
+ running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
1643
1877
  if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
1644
1878
  self.running_batch.batch_is_full = True
1645
- break
1646
-
1647
1879
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
1648
1880
  # In prefill mode, prealloc queue and transfer queue can also take memory,
1649
1881
  # so we also need to check the actual available size of the req_to_token_pool.
1650
1882
  if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
1651
1883
  self.running_batch.batch_is_full = True
1884
+
1885
+ if self.running_batch.batch_is_full:
1886
+ if not self.try_preemption:
1887
+ break
1888
+ if not adder.preempt_to_schedule(req, self.server_args):
1652
1889
  break
1653
1890
 
1654
1891
  if self.enable_hicache_storage:
@@ -1658,7 +1895,11 @@ class Scheduler(
1658
1895
  continue
1659
1896
 
1660
1897
  req.init_next_round_input(self.tree_cache)
1661
- res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
1898
+ res = adder.add_one_req(
1899
+ req,
1900
+ has_chunked_req=(self.chunked_req is not None),
1901
+ truncation_align_size=self.truncation_align_size,
1902
+ )
1662
1903
 
1663
1904
  if res != AddReqResult.CONTINUE:
1664
1905
  if res == AddReqResult.NO_TOKEN:
@@ -1679,11 +1920,14 @@ class Scheduler(
1679
1920
  if self.enable_metrics:
1680
1921
  # only record queue time when enable_metrics is True to avoid overhead
1681
1922
  for req in can_run_list:
1682
- req.queue_time_end = time.perf_counter()
1923
+ req.add_latency(RequestStage.PREFILL_WAITING)
1683
1924
 
1684
1925
  self.waiting_queue = [
1685
1926
  x for x in self.waiting_queue if x not in set(can_run_list)
1686
1927
  ]
1928
+ if adder.preempt_list:
1929
+ for req in adder.preempt_list:
1930
+ self._add_request_to_queue(req)
1687
1931
 
1688
1932
  if adder.new_chunked_req is not None:
1689
1933
  assert self.chunked_req is None
@@ -1694,7 +1938,16 @@ class Scheduler(
1694
1938
 
1695
1939
  # Print stats
1696
1940
  if self.current_scheduler_metrics_enabled():
1697
- self.log_prefill_stats(adder, can_run_list, running_bs)
1941
+ self.log_prefill_stats(adder, can_run_list, running_bs, 0)
1942
+
1943
+ for req in can_run_list:
1944
+ if req.time_stats.forward_entry_time == 0:
1945
+ # Avoid updating the chunked request multiple times
1946
+ req.time_stats.forward_entry_time = time.perf_counter()
1947
+ if self.enable_metrics:
1948
+ self.metrics_collector.observe_queue_time(
1949
+ req.time_stats.get_queueing_time(),
1950
+ )
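Note: the block above stamps forward_entry_time only once per request (so a chunked prefill is not re-stamped on every chunk) and then reports the queueing delay to the metrics collector. A hedged sketch of what get_queueing_time plausibly measures, with a reduced, hypothetical stand-in for the time-stats object:

    import time

    class TimeStats:                          # reduced, hypothetical stand-in
        wait_queue_entry_time = 0.0
        forward_entry_time = 0.0
        def get_queueing_time(self):
            return self.forward_entry_time - self.wait_queue_entry_time

    stats = TimeStats()
    stats.wait_queue_entry_time = time.perf_counter()
    time.sleep(0.01)                          # pretend the request waited in the queue
    if stats.forward_entry_time == 0:         # stamp only once, as in the hunk above
        stats.forward_entry_time = time.perf_counter()
    print(f"queued for ~{stats.get_queueing_time():.3f}s")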
1698
1951
 
1699
1952
  # Create a new batch
1700
1953
  new_batch = ScheduleBatch.init_new(
@@ -1749,19 +2002,25 @@ class Scheduler(
1749
2002
  TEST_RETRACT and batch.batch_size() > 10
1750
2003
  ):
1751
2004
  old_ratio = self.new_token_ratio
1752
-
1753
- retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
1754
- num_retracted_reqs = len(retracted_reqs)
2005
+ retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
2006
+ self.server_args
2007
+ )
2008
+ self.num_retracted_reqs = len(retracted_reqs)
1755
2009
  self.new_token_ratio = new_token_ratio
2010
+ for req in reqs_to_abort:
2011
+ self.send_to_tokenizer.send_pyobj(
2012
+ AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
2013
+ )
1756
2014
 
1757
2015
  logger.info(
1758
2016
  "KV cache pool is full. Retract requests. "
1759
- f"#retracted_reqs: {num_retracted_reqs}, "
1760
- f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
2017
+ f"#retracted_reqs: {len(retracted_reqs)}, "
2018
+ f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
2019
+ f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
1761
2020
  )
1762
2021
 
1763
- self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
1764
- self.total_retracted_reqs += num_retracted_reqs
2022
+ for req in retracted_reqs:
2023
+ self._add_request_to_queue(req, is_retracted=True)
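Note: a condensed sketch of the new retraction flow above: retract_decode now also returns requests that cannot be retried, which are aborted, while the rest re-enter the waiting queue flagged as retracted (the callables below are placeholders):

    def handle_retraction(retracted, reqs_to_abort, requeue, abort):
        for req in reqs_to_abort:
            abort(req)                        # cannot be safely retried
        for req in retracted:
            requeue(req, is_retracted=True)   # goes back to the waiting queue

    handle_retraction(
        retracted=["req-a", "req-b"],
        reqs_to_abort=["req-c"],
        requeue=lambda r, is_retracted: print("requeue", r, is_retracted),
        abort=lambda r: print("abort", r),
    )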
1765
2024
  else:
1766
2025
  self.new_token_ratio = max(
1767
2026
  self.new_token_ratio - self.new_token_ratio_decay,
@@ -1789,37 +2048,25 @@ class Scheduler(
1789
2048
 
1790
2049
  # Run forward
1791
2050
  if self.is_generation:
2051
+
2052
+ batch_or_worker_batch = batch
2053
+
1792
2054
  if self.spec_algorithm.is_none():
1793
- model_worker_batch = batch.get_model_worker_batch()
2055
+ # FIXME(lsyin): remove this if and finally unify the abstraction
2056
+ batch_or_worker_batch = batch.get_model_worker_batch()
1794
2057
 
1795
- # update the consumer index of hicache to the running batch
1796
- self.tp_worker.set_hicache_consumer(
1797
- model_worker_batch.hicache_consumer_index
2058
+ forward_batch_output = self.model_worker.forward_batch_generation(
2059
+ batch_or_worker_batch
2060
+ )
2061
+
2062
+ if not self.spec_algorithm.is_none():
2063
+ # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
2064
+ self.udpate_spec_metrics(
2065
+ batch.batch_size(), forward_batch_output.num_accepted_tokens
1798
2066
  )
1799
- if self.pp_group.is_last_rank:
1800
- logits_output, next_token_ids, can_run_cuda_graph = (
1801
- self.tp_worker.forward_batch_generation(model_worker_batch)
1802
- )
1803
- else:
1804
- pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
1805
- self.tp_worker.forward_batch_generation(model_worker_batch)
1806
- )
1807
- bid = model_worker_batch.bid
1808
- else:
1809
- (
1810
- logits_output,
1811
- next_token_ids,
1812
- bid,
1813
- num_accepted_tokens,
1814
- can_run_cuda_graph,
1815
- ) = self.draft_worker.forward_batch_speculative_generation(batch)
1816
- bs = batch.batch_size()
1817
- self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
1818
- self.spec_num_total_forward_ct += bs
1819
- self.num_generated_tokens += num_accepted_tokens
1820
-
1821
- if self.pp_group.is_last_rank:
1822
- batch.output_ids = next_token_ids
2067
+
2068
+ # update batch's output ids
2069
+ batch.output_ids = forward_batch_output.next_token_ids
1823
2070
 
1824
2071
  # These 2 values are needed for processing the output, but the values can be
1825
2072
  # modified by overlap schedule. So we have to copy them here so that
@@ -1828,6 +2075,7 @@ class Scheduler(
1828
2075
  extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
1829
2076
  else:
1830
2077
  extend_input_len_per_req = None
2078
+
1831
2079
  if batch.return_logprob:
1832
2080
  extend_logprob_start_len_per_req = [
1833
2081
  req.extend_logprob_start_len for req in batch.reqs
@@ -1835,25 +2083,15 @@ class Scheduler(
1835
2083
  else:
1836
2084
  extend_logprob_start_len_per_req = None
1837
2085
 
1838
- ret = GenerationBatchResult(
1839
- logits_output=logits_output if self.pp_group.is_last_rank else None,
1840
- pp_hidden_states_proxy_tensors=(
1841
- pp_hidden_states_proxy_tensors
1842
- if not self.pp_group.is_last_rank
1843
- else None
1844
- ),
1845
- next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
2086
+ return GenerationBatchResult.from_forward_batch_output(
2087
+ forward_batch_output=forward_batch_output,
1846
2088
  extend_input_len_per_req=extend_input_len_per_req,
1847
2089
  extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
1848
- bid=bid,
1849
- can_run_cuda_graph=can_run_cuda_graph,
1850
2090
  )
1851
2091
  else: # embedding or reward model
1852
2092
  model_worker_batch = batch.get_model_worker_batch()
1853
2093
  embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
1854
- ret = EmbeddingBatchResult(
1855
- embeddings=embeddings, bid=model_worker_batch.bid
1856
- )
2094
+ ret = EmbeddingBatchResult(embeddings=embeddings)
1857
2095
  return ret
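Note: the refactor above funnels spec and non-spec decoding through a single forward_batch_generation call whose output carries the next token ids; a simplified sketch of that shape with illustrative stand-in classes (not the real sglang types):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class ForwardBatchOutput:                  # illustrative stand-in
        next_token_ids: List[int]
        num_accepted_tokens: int = 0

    class ModelWorker:                         # illustrative stand-in
        def forward_batch_generation(self, batch):
            return ForwardBatchOutput(next_token_ids=[7, 11])

    class Batch:
        def __init__(self, reqs):
            self.reqs, self.output_ids = reqs, None

    worker, batch = ModelWorker(), Batch(reqs=["r0", "r1"])
    out = worker.forward_batch_generation(batch)   # same entry point for spec / non-spec
    batch.output_ids = out.next_token_ids
    print(batch.output_ids)                        # [7, 11]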
1858
2096
 
1859
2097
  def process_batch_result(
@@ -1864,8 +2102,14 @@ class Scheduler(
1864
2102
  ):
1865
2103
  if batch.forward_mode.is_decode():
1866
2104
  self.process_batch_result_decode(batch, result, launch_done)
2105
+ if self.enable_trace:
2106
+ trace_slice_batch("decode loop", batch.reqs)
2107
+
1867
2108
  elif batch.forward_mode.is_extend():
1868
2109
  self.process_batch_result_prefill(batch, result, launch_done)
2110
+ if self.enable_trace:
2111
+ trace_slice_batch("prefill", batch.reqs)
2112
+
1869
2113
  elif batch.forward_mode.is_idle():
1870
2114
  if self.enable_overlap:
1871
2115
  self.tp_worker.resolve_last_batch_result(launch_done)
@@ -1897,86 +2141,6 @@ class Scheduler(
1897
2141
  disable_overlap_schedule=self.server_args.disable_overlap_schedule,
1898
2142
  )
1899
2143
 
1900
- def handle_dp_balance_data(self, local_batch: ScheduleBatch):
1901
- def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
1902
- """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
1903
- recv_list = self.recv_dp_balance_id_this_term
1904
- assert len(recv_list) <= 511, (
1905
- "The number of requests received this round is too large. "
1906
- "Please increase gather_tensor_size and onfly_info_size."
1907
- )
1908
- # The maximum size of the tensor used for gathering data from all workers.
1909
- gather_tensor_size = 512
1910
-
1911
- # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
1912
- recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
1913
- recv_tensor[0] = holding_tokens_list
1914
- recv_tensor[1] = len(
1915
- recv_list
1916
- ) # The first element is the length of the list.
1917
- recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
1918
- recv_list, dtype=torch.int32
1919
- )
1920
-
1921
- if self.tp_rank == 0:
1922
- gathered_list = [
1923
- torch.zeros(gather_tensor_size, dtype=torch.int32)
1924
- for _ in range(self.balance_meta.num_workers)
1925
- ]
1926
- else:
1927
- gathered_list = None
1928
-
1929
- torch.distributed.gather(
1930
- recv_tensor, gathered_list, group=self.tp_cpu_group
1931
- )
1932
-
1933
- gathered_id_list_per_worker = None
1934
- if self.tp_rank == 0:
1935
- gathered_id_list_per_worker = []
1936
- holding_tokens_list = []
1937
- for tensor in gathered_list:
1938
- holding_tokens_list.append(tensor[0].item())
1939
- list_length = tensor[1].item()
1940
- gathered_id_list_per_worker.append(
1941
- tensor[2 : list_length + 2].tolist()
1942
- )
1943
-
1944
- return gathered_id_list_per_worker, holding_tokens_list
1945
-
1946
- def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
1947
- meta = self.balance_meta
1948
-
1949
- with meta.mutex:
1950
- onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
1951
- assert len(new_recv_rid_lists) == len(
1952
- onfly_list
1953
- ), "num_worker not equal"
1954
- # 1.Check if the rid received by each worker this round is present in onfly.
1955
- # If it is, remove the corresponding onfly item.
1956
- worker_id = 0
1957
- for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
1958
- for new_recv_rid in new_recv_rids:
1959
- assert (
1960
- new_recv_rid in on_fly_reqs
1961
- ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
1962
- del on_fly_reqs[new_recv_rid]
1963
- worker_id += 1
1964
- # 2. Atomically write local_tokens and onfly into shm under the mutex
1965
- meta.set_shared_onfly_info(onfly_list)
1966
- meta.set_shared_local_tokens(local_tokens)
1967
-
1968
- holding_tokens = self.get_load()
1969
-
1970
- new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
1971
- holding_tokens
1972
- )
1973
-
1974
- self.recv_dp_balance_id_this_term.clear()
1975
- if self.tp_rank == 0: # only first worker write info
1976
- write_shared_dp_balance_info(
1977
- new_recv_dp_balance_id_list, holding_token_list
1978
- )
1979
-
1980
2144
  @staticmethod
1981
2145
  def prepare_mlp_sync_batch_raw(
1982
2146
  local_batch: ScheduleBatch,
@@ -2104,12 +2268,13 @@ class Scheduler(
2104
2268
  if req.finished(): # It is aborted by AbortReq
2105
2269
  num_ready_reqs += 1
2106
2270
  continue
2271
+
2107
2272
  req.grammar = req.grammar.result(timeout=0.03)
2108
2273
  self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
2109
2274
  if req.grammar is INVALID_GRAMMAR_OBJ:
2110
- req.set_finish_with_abort(
2111
- f"Invalid grammar request: {req.grammar_key=}"
2112
- )
2275
+ error_msg = f"Invalid grammar request: {req.grammar_key=}"
2276
+ req.set_finish_with_abort(error_msg)
2277
+
2113
2278
  num_ready_reqs += 1
2114
2279
  except futures._base.TimeoutError:
2115
2280
  req.grammar_wait_ct += 1
@@ -2141,9 +2306,8 @@ class Scheduler(
2141
2306
  req.grammar = req.grammar.result()
2142
2307
  self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
2143
2308
  if req.grammar is INVALID_GRAMMAR_OBJ:
2144
- req.set_finish_with_abort(
2145
- f"Invalid grammar request: {req.grammar_key=}"
2146
- )
2309
+ error_msg = f"Invalid grammar request: {req.grammar_key=}"
2310
+ req.set_finish_with_abort(error_msg)
2147
2311
  else:
2148
2312
  num_ready_reqs_max = num_ready_reqs
2149
2313
  num_timeout_reqs_max = num_timeout_reqs
@@ -2151,12 +2315,14 @@ class Scheduler(
2151
2315
  for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
2152
2316
  req = self.grammar_queue[i]
2153
2317
  req.grammar.cancel()
2318
+ self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
2154
2319
  error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
2155
2320
  req.set_finish_with_abort(error_msg)
2156
- self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
2321
+
2157
2322
  num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
2158
2323
 
2159
- self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
2324
+ for req in self.grammar_queue[:num_ready_reqs]:
2325
+ self._add_request_to_queue(req)
2160
2326
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]
2161
2327
 
2162
2328
  def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
@@ -2248,9 +2414,8 @@ class Scheduler(
2248
2414
  self.req_to_token_pool.clear()
2249
2415
  self.token_to_kv_pool_allocator.clear()
2250
2416
 
2251
- if not self.spec_algorithm.is_none():
2252
- self.draft_worker.model_runner.req_to_token_pool.clear()
2253
- self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
2417
+ if self.draft_worker:
2418
+ self.draft_worker.clear_cache_pool()
2254
2419
 
2255
2420
  self.num_generated_tokens = 0
2256
2421
  self.forward_ct_decode = 0
@@ -2270,39 +2435,50 @@ class Scheduler(
2270
2435
  if_success = False
2271
2436
  return if_success
2272
2437
 
2273
- def get_load(self):
2438
+ def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
2274
2439
  # TODO(lsyin): use dynamically maintained num_waiting_tokens
2440
+
2275
2441
  if self.is_hybrid:
2276
- load_full = (
2442
+ num_tokens_full = (
2277
2443
  self.full_tokens_per_layer
2278
2444
  - self.token_to_kv_pool_allocator.full_available_size()
2279
2445
  - self.tree_cache.full_evictable_size()
2280
2446
  )
2281
- load_swa = (
2447
+ num_tokens_swa = (
2282
2448
  self.swa_tokens_per_layer
2283
2449
  - self.token_to_kv_pool_allocator.swa_available_size()
2284
2450
  - self.tree_cache.swa_evictable_size()
2285
2451
  )
2286
- load = max(load_full, load_swa)
2452
+ num_tokens = max(num_tokens_full, num_tokens_swa)
2287
2453
  else:
2288
- load = (
2454
+ num_tokens = (
2289
2455
  self.max_total_num_tokens
2290
2456
  - self.token_to_kv_pool_allocator.available_size()
2291
2457
  - self.tree_cache.evictable_size()
2292
2458
  )
2293
- load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
2459
+
2460
+ # Tokens in waiting queue, bootstrap queue, prealloc queue
2461
+ num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
2462
+ num_waiting_reqs = len(self.waiting_queue)
2294
2463
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
2295
- load += sum(
2464
+ num_tokens += sum(
2296
2465
  len(req.origin_input_ids)
2297
2466
  for req in self.disagg_prefill_bootstrap_queue.queue
2298
2467
  )
2468
+ num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
2299
2469
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
2300
- load += sum(
2470
+ num_tokens += sum(
2301
2471
  len(req.req.origin_input_ids)
2302
2472
  for req in self.disagg_decode_prealloc_queue.queue
2303
2473
  )
2474
+ num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)
2304
2475
 
2305
- return load
2476
+ return GetLoadReqOutput(
2477
+ dp_rank=self.dp_rank,
2478
+ num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
2479
+ num_waiting_reqs=num_waiting_reqs,
2480
+ num_tokens=num_tokens,
2481
+ )
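Note: a back-of-the-envelope version of the load computed above: KV-pool occupancy (capacity minus free and evictable slots) plus the prompt tokens of everything still queued; all numbers below are invented:

    max_total_num_tokens = 10_000
    available_size = 6_000
    evictable_size = 1_500
    num_tokens = max_total_num_tokens - available_size - evictable_size   # 2500 in-use tokens

    waiting_prompt_lens = [128, 512, 64]         # origin_input_ids lengths of queued requests
    num_tokens += sum(waiting_prompt_lens)
    num_waiting_reqs = len(waiting_prompt_lens)

    print(num_tokens, num_waiting_reqs)          # 3204 3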
2306
2482
 
2307
2483
  def get_internal_state(self, recv_req: GetInternalStateReq):
2308
2484
  ret = dict(global_server_args_dict)
@@ -2317,10 +2493,9 @@ class Scheduler(
2317
2493
  "token_capacity": int(self.max_total_num_tokens),
2318
2494
  }
2319
2495
 
2320
- if not _is_cpu:
2321
- ret["memory_usage"]["cuda_graph"] = round(
2322
- self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
2323
- )
2496
+ ret["memory_usage"]["graph"] = round(
2497
+ self.tp_worker.worker.model_runner.graph_mem_usage, 2
2498
+ )
2324
2499
 
2325
2500
  if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
2326
2501
  ret["avg_spec_accept_length"] = (
@@ -2329,8 +2504,6 @@ class Scheduler(
2329
2504
  if RECORD_STEP_TIME:
2330
2505
  ret["step_time_dict"] = self.step_time_dict
2331
2506
 
2332
- ret["load"] = self.get_load()
2333
-
2334
2507
  return GetInternalStateReqOutput(internal_state=ret)
2335
2508
 
2336
2509
  def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2406,7 +2579,7 @@ class Scheduler(
2406
2579
  if self.enable_hicache_storage:
2407
2580
  # to release prefetch events associated with the request
2408
2581
  self.tree_cache.release_aborted_request(req.rid)
2409
- self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
2582
+ self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
2410
2583
  # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
2411
2584
  if self.disaggregation_mode == DisaggregationMode.DECODE:
2412
2585
  self.tree_cache.cache_finished_req(req)
@@ -2427,31 +2600,31 @@ class Scheduler(
2427
2600
  # Delete requests not in the waiting queue when PD disaggregation is enabled
2428
2601
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
2429
2602
  # Abort requests that have not yet been bootstrapped
2430
- for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
2431
- logger.debug(f"Abort bootstrap queue request. {req.rid=}")
2603
+ for req in self.disagg_prefill_bootstrap_queue.queue:
2432
2604
  if recv_req.abort_all or req.rid.startswith(recv_req.rid):
2605
+ logger.debug(f"Abort bootstrap queue request. {req.rid=}")
2433
2606
  if hasattr(req.disagg_kv_sender, "abort"):
2434
2607
  req.disagg_kv_sender.abort()
2435
2608
 
2436
2609
  # Abort in-flight requests
2437
- for i, req in enumerate(self.disagg_prefill_inflight_queue):
2438
- logger.debug(f"Abort inflight queue request. {req.rid=}")
2610
+ for req in self.disagg_prefill_inflight_queue:
2439
2611
  if recv_req.abort_all or req.rid.startswith(recv_req.rid):
2612
+ logger.debug(f"Abort inflight queue request. {req.rid=}")
2440
2613
  if hasattr(req.disagg_kv_sender, "abort"):
2441
2614
  req.disagg_kv_sender.abort()
2442
2615
 
2443
2616
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
2444
2617
  # Abort requests that have not yet finished preallocation
2445
- for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
2446
- logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
2618
+ for decode_req in self.disagg_decode_prealloc_queue.queue:
2447
2619
  if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
2620
+ logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
2448
2621
  if hasattr(decode_req.kv_receiver, "abort"):
2449
2622
  decode_req.kv_receiver.abort()
2450
2623
 
2451
2624
  # Abort requests waiting for kvcache to release tree cache
2452
- for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
2453
- logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
2625
+ for decode_req in self.disagg_decode_transfer_queue.queue:
2454
2626
  if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
2627
+ logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
2455
2628
  if hasattr(decode_req.kv_receiver, "abort"):
2456
2629
  decode_req.kv_receiver.abort()
2457
2630
 
@@ -2494,6 +2667,22 @@ class Scheduler(
2494
2667
  self.send_to_detokenizer.send_pyobj(recv_req)
2495
2668
  return recv_req
2496
2669
 
2670
+ def init_weights_send_group_for_remote_instance(
2671
+ self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
2672
+ ):
2673
+ """Init the seed and client instance communication group."""
2674
+ success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
2675
+ recv_req
2676
+ )
2677
+ return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)
2678
+
2679
+ def send_weights_to_remote_instance(
2680
+ self, recv_req: SendWeightsToRemoteInstanceReqInput
2681
+ ):
2682
+ """Send the seed instance weights to the destination instance."""
2683
+ success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
2684
+ return SendWeightsToRemoteInstanceReqOutput(success, message)
2685
+
2497
2686
  def slow_down(self, recv_req: SlowDownReqInput):
2498
2687
  t = recv_req.forward_sleep_time
2499
2688
  if t is not None and t <= 0:
@@ -2502,11 +2691,12 @@ class Scheduler(
2502
2691
  return SlowDownReqOutput()
2503
2692
 
2504
2693
  def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
2505
- if recv_req == ExpertDistributionReq.START_RECORD:
2694
+ action = recv_req.action
2695
+ if action == ExpertDistributionReqType.START_RECORD:
2506
2696
  get_global_expert_distribution_recorder().start_record()
2507
- elif recv_req == ExpertDistributionReq.STOP_RECORD:
2697
+ elif action == ExpertDistributionReqType.STOP_RECORD:
2508
2698
  get_global_expert_distribution_recorder().stop_record()
2509
- elif recv_req == ExpertDistributionReq.DUMP_RECORD:
2699
+ elif action == ExpertDistributionReqType.DUMP_RECORD:
2510
2700
  get_global_expert_distribution_recorder().dump_record()
2511
2701
  else:
2512
2702
  raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
@@ -2589,7 +2779,8 @@ class IdleSleeper:
2589
2779
 
2590
2780
 
2591
2781
  def is_health_check_generate_req(recv_req):
2592
- return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
2782
+ rid = getattr(recv_req, "rid", None)
2783
+ return rid is not None and rid.startswith("HEALTH_CHECK")
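Note: the guard above now tolerates a request whose rid attribute exists but is None; under the old getattr(recv_req, "rid", "") form such a request would raise when startswith is called on None. A tiny illustration with hypothetical request objects:

    from types import SimpleNamespace

    def is_health_check(recv_req):
        rid = getattr(recv_req, "rid", None)
        return rid is not None and rid.startswith("HEALTH_CHECK")

    print(is_health_check(SimpleNamespace(rid="HEALTH_CHECK_42")))   # True
    print(is_health_check(SimpleNamespace(rid=None)))                # False instead of AttributeError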
2593
2784
 
2594
2785
 
2595
2786
  def is_work_request(recv_req):
@@ -2613,10 +2804,12 @@ def run_scheduler_process(
2613
2804
  pp_rank: int,
2614
2805
  dp_rank: Optional[int],
2615
2806
  pipe_writer,
2616
- balance_meta: Optional[DPBalanceMeta] = None,
2617
2807
  ):
2618
- # Generate the prefix
2808
+ # Generate the logger prefix
2619
2809
  prefix = ""
2810
+ if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
2811
+ # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
2812
+ dp_rank = int(os.environ["SGLANG_DP_RANK"])
2620
2813
  if dp_rank is not None:
2621
2814
  prefix += f" DP{dp_rank}"
2622
2815
  if server_args.tp_size > 1:
@@ -2632,10 +2825,6 @@ def run_scheduler_process(
2632
2825
  kill_itself_when_parent_died()
2633
2826
  parent_process = psutil.Process().parent()
2634
2827
 
2635
- # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
2636
- if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
2637
- dp_rank = int(os.environ["SGLANG_DP_RANK"])
2638
-
2639
2828
  # Configure the logger
2640
2829
  configure_logger(server_args, prefix=prefix)
2641
2830
  suppress_other_loggers()
@@ -2643,6 +2832,15 @@ def run_scheduler_process(
2643
2832
  # Set cpu affinity to this gpu process
2644
2833
  if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
2645
2834
  set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
2835
+ if (numa_node := server_args.numa_node) is not None:
2836
+ numa_bind_to_node(numa_node[gpu_id])
2837
+
2838
+ # Set up tracing
2839
+ if server_args.enable_trace:
2840
+ process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
2841
+ if server_args.disaggregation_mode == "null":
2842
+ thread_label = "Scheduler"
2843
+ trace_set_thread_info(thread_label, tp_rank, dp_rank)
2646
2844
 
2647
2845
  # Create a scheduler and run the event loop
2648
2846
  try:
@@ -2654,7 +2852,6 @@ def run_scheduler_process(
2654
2852
  moe_ep_rank,
2655
2853
  pp_rank,
2656
2854
  dp_rank,
2657
- dp_balance_meta=balance_meta,
2658
2855
  )
2659
2856
  pipe_writer.send(
2660
2857
  {