sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -25,12 +25,14 @@ from concurrent import futures
  from dataclasses import dataclass
  from http import HTTPStatus
  from types import SimpleNamespace
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import Deque, Dict, List, Optional, Tuple, Union

  import psutil
  import setproctitle
  import torch
  import zmq
+ from torch.cuda import Stream as CudaStream
+ from torch.cuda import StreamContext as CudaStreamContext
  from torch.distributed import barrier

  from sglang.global_config import global_config
@@ -44,6 +46,9 @@ from sglang.srt.disaggregation.decode import (
  DecodeTransferQueue,
  SchedulerDisaggregationDecodeMixin,
  )
+ from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
+ DecodeKVCacheOffloadManager,
+ )
  from sglang.srt.disaggregation.prefill import (
  PrefillBootstrapQueue,
  SchedulerDisaggregationPrefillMixin,
@@ -57,11 +62,6 @@ from sglang.srt.disaggregation.utils import (
  )
  from sglang.srt.distributed import get_pp_group, get_world_group
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
- from sglang.srt.hf_transformers_utils import (
- get_processor,
- get_tokenizer,
- get_tokenizer_from_processor,
- )
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.moe import initialize_moe_config
@@ -72,20 +72,26 @@ from sglang.srt.managers.io_struct import (
  ClearHiCacheReqInput,
  ClearHiCacheReqOutput,
  CloseSessionReqInput,
+ DestroyWeightsUpdateGroupReqInput,
  ExpertDistributionReq,
  ExpertDistributionReqOutput,
+ ExpertDistributionReqType,
  FlushCacheReqInput,
  FlushCacheReqOutput,
  FreezeGCReq,
  GetInternalStateReq,
  GetInternalStateReqOutput,
+ GetLoadReqInput,
+ GetLoadReqOutput,
  GetWeightsByNameReqInput,
  HealthCheckOutput,
+ InitWeightsSendGroupForRemoteInstanceReqInput,
+ InitWeightsSendGroupForRemoteInstanceReqOutput,
  InitWeightsUpdateGroupReqInput,
  LoadLoRAAdapterReqInput,
  LoadLoRAAdapterReqOutput,
  MultiTokenizerRegisterReq,
- MultiTokenizerWarpper,
+ MultiTokenizerWrapper,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
@@ -93,6 +99,8 @@ from sglang.srt.managers.io_struct import (
  ResumeMemoryOccupationReqInput,
  RpcReqInput,
  RpcReqOutput,
+ SendWeightsToRemoteInstanceReqInput,
+ SendWeightsToRemoteInstanceReqOutput,
  SetInternalStateReq,
  SetInternalStateReqOutput,
  SlowDownReqInput,
@@ -106,10 +114,13 @@ from sglang.srt.managers.io_struct import (
  UpdateWeightsFromTensorReqInput,
  )
  from sglang.srt.managers.mm_utils import init_embedding_cache
+ from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
  from sglang.srt.managers.schedule_batch import (
  FINISH_ABORT,
+ ModelWorkerBatch,
  MultimodalInputs,
  Req,
+ RequestStage,
  ScheduleBatch,
  global_server_args_dict,
  )
@@ -132,19 +143,28 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
  SchedulerUpdateWeightsMixin,
  )
  from sglang.srt.managers.session_controller import Session
- from sglang.srt.managers.tp_worker import TpModelWorker
- from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
- from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
+ from sglang.srt.managers.utils import validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
- from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
- from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
+ from sglang.srt.model_executor.forward_batch_info import (
+ ForwardBatch,
+ ForwardMode,
+ PPProxyTensors,
+ )
  from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+ from sglang.srt.tracing.trace import (
+ process_tracing_init,
+ trace_set_proc_propagate_context,
+ trace_set_thread_info,
+ trace_slice_batch,
+ trace_slice_end,
+ trace_slice_start,
+ )
  from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
  from sglang.srt.utils import (
  DynamicGradMode,
@@ -155,9 +175,10 @@ from sglang.srt.utils import (
  freeze_gc,
  get_available_gpu_memory,
  get_bool_env_var,
+ get_int_env_var,
  get_zmq_socket,
- is_cpu,
  kill_itself_when_parent_died,
+ numa_bind_to_node,
  point_to_point_pyobj,
  pyspy_dump_schedulers,
  require_mlp_sync,
@@ -166,6 +187,11 @@ from sglang.srt.utils import (
  set_random_seed,
  suppress_other_loggers,
  )
+ from sglang.srt.utils.hf_transformers_utils import (
+ get_processor,
+ get_tokenizer,
+ get_tokenizer_from_processor,
+ )
  from sglang.utils import TypeBasedDispatcher, get_exception_traceback

  logger = logging.getLogger(__name__)
@@ -174,24 +200,67 @@ logger = logging.getLogger(__name__)
  TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
  GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

- _is_cpu = is_cpu()
-

  @dataclass
  class GenerationBatchResult:
- logits_output: Optional[LogitsProcessorOutput]
- pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
- next_token_ids: Optional[List[int]]
- extend_input_len_per_req: List[int]
- extend_logprob_start_len_per_req: List[int]
- bid: int
- can_run_cuda_graph: bool
+ logits_output: Optional[LogitsProcessorOutput] = None
+ pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
+ next_token_ids: Optional[torch.Tensor] = None
+ num_accepted_tokens: Optional[int] = None
+ can_run_cuda_graph: bool = False
+
+ # For output processing
+ extend_input_len_per_req: Optional[List[int]] = None
+ extend_logprob_start_len_per_req: Optional[List[int]] = None
+
+ # For overlap scheduling
+ copy_done: Optional[torch.cuda.Event] = None
+ delay_sample_launch: bool = False
+ forward_batch: Optional[ForwardBatch] = None
+ future_indices: Optional[FutureIndices] = None
+
+ def copy_to_cpu(self, return_logprob: bool = False):
+ """Copy tensors to CPU in overlap scheduling.
+ Only the tensors which are needed for processing results are copied,
+ e.g., next_token_ids, logits outputs
+ """
+ if return_logprob:
+ if self.logits_output.next_token_logits is not None:
+ self.logits_output.next_token_logits = (
+ self.logits_output.next_token_logits.to("cpu", non_blocking=True)
+ )
+ if self.logits_output.input_token_logprobs is not None:
+ self.logits_output.input_token_logprobs = (
+ self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+ )
+ if self.logits_output.hidden_states is not None:
+ self.logits_output.hidden_states = self.logits_output.hidden_states.to(
+ "cpu", non_blocking=True
+ )
+ self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
+ self.copy_done.record()
+
+ @classmethod
+ def from_pp_proxy(
+ cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+ ):
+ # TODO(lsyin): refactor PP and avoid using dict
+ proxy_dict = next_pp_outputs.tensors
+ return cls(
+ logits_output=logits_output,
+ pp_hidden_states_proxy_tensors=None,
+ next_token_ids=next_pp_outputs["next_token_ids"],
+ extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+ extend_logprob_start_len_per_req=proxy_dict.get(
+ "extend_logprob_start_len_per_req", None
+ ),
+ can_run_cuda_graph=can_run_cuda_graph,
+ )


  @dataclass
  class EmbeddingBatchResult:
  embeddings: torch.Tensor
- bid: int


  class Scheduler(
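Note (not part of the diff): the new GenerationBatchResult.copy_to_cpu above uses a standard PyTorch idiom for overlap scheduling: enqueue non-blocking device-to-host copies, record a CUDA event right after them, and synchronize on that event only when the CPU actually needs the values. A minimal, self-contained sketch of that idiom, with made-up tensor names, is:

    import torch

    def launch_async_copy(next_token_ids: torch.Tensor):
        # Enqueue a device-to-host copy on the current stream and mark its completion
        # with an event. The copy only truly overlaps when the destination is pinned
        # host memory; with pageable memory it may still serialize.
        copy_done = torch.cuda.Event()
        host_ids = next_token_ids.to("cpu", non_blocking=True)
        copy_done.record()  # recorded after the copy was enqueued on the stream
        return host_ids, copy_done

    if torch.cuda.is_available():
        ids_gpu = torch.randint(0, 32000, (8,), device="cuda")
        host_ids, copy_done = launch_async_copy(ids_gpu)
        # ... the scheduler can keep doing CPU work here while the copy is in flight ...
        copy_done.synchronize()  # block only when the token ids are actually consumed
        print(host_ids.tolist())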
@@ -204,6 +273,48 @@ class Scheduler(
  ):
  """A scheduler that manages a tensor parallel GPU worker."""

+ def launch_draft_worker(
+ self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+ ):
+ if self.spec_algorithm.is_eagle():
+ from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+ self.draft_worker = EAGLEWorker(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
+ elif self.spec_algorithm.is_standalone():
+ from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+ self.draft_worker = StandaloneWorker(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
+ elif self.spec_algorithm.is_ngram():
+ from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+ self.draft_worker = NGRAMWorker(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
+ else:
+ self.draft_worker = None
+
  def __init__(
  self,
  server_args: ServerArgs,
@@ -213,7 +324,6 @@ class Scheduler(
  moe_ep_rank: int,
  pp_rank: int,
  dp_rank: Optional[int],
- dp_balance_meta: Optional[DPBalanceMeta] = None,
  ):
  # Parse args
  self.server_args = server_args
@@ -226,6 +336,13 @@ class Scheduler(
  self.pp_size = server_args.pp_size
  self.dp_size = server_args.dp_size
  self.schedule_policy = server_args.schedule_policy
+ self.enable_priority_scheduling = server_args.enable_priority_scheduling
+ self.schedule_low_priority_values_first = (
+ server_args.schedule_low_priority_values_first
+ )
+ self.priority_scheduling_preemption_threshold = (
+ server_args.priority_scheduling_preemption_threshold
+ )
  self.enable_lora = server_args.enable_lora
  self.max_loras_per_batch = server_args.max_loras_per_batch
  self.enable_overlap = not server_args.disable_overlap_schedule
@@ -234,7 +351,10 @@ class Scheduler(
  self.enable_metrics_for_all_schedulers = (
  server_args.enable_metrics_for_all_schedulers
  )
- self.enable_kv_cache_events = server_args.kv_events_config is not None
+ self.enable_kv_cache_events = bool(
+ server_args.kv_events_config and tp_rank == 0
+ )
+ self.enable_trace = server_args.enable_trace
  self.stream_interval = server_args.stream_interval
  self.spec_algorithm = SpeculativeAlgorithm.from_string(
  server_args.speculative_algorithm
@@ -320,12 +440,10 @@ class Scheduler(
  logger.info("Overlap scheduler is disabled for embedding models.")

  # Launch a tensor parallel worker
- if self.enable_overlap:
- TpWorkerClass = TpModelWorkerClient
- else:
- TpWorkerClass = TpModelWorker

- self.tp_worker = TpWorkerClass(
+ from sglang.srt.managers.tp_worker import TpModelWorker
+
+ self.tp_worker = TpModelWorker(
  server_args=server_args,
  gpu_id=gpu_id,
  tp_rank=tp_rank,
@@ -336,20 +454,15 @@ class Scheduler(
  )

  # Launch a draft worker for speculative decoding
- if self.spec_algorithm.is_eagle():
- from sglang.srt.speculative.eagle_worker import EAGLEWorker
+ self.launch_draft_worker(
+ gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+ )

- self.draft_worker = EAGLEWorker(
- gpu_id=gpu_id,
- tp_rank=tp_rank,
- moe_ep_rank=moe_ep_rank,
- server_args=server_args,
- nccl_port=port_args.nccl_port,
- target_worker=self.tp_worker,
- dp_rank=dp_rank,
- )
+ # Dispatch the model worker
+ if self.spec_algorithm.is_none():
+ self.model_worker = self.tp_worker
  else:
- self.draft_worker = None
+ self.model_worker = self.draft_worker

  # Get token and memory info from the model worker
  (
@@ -366,8 +479,8 @@ class Scheduler(
  _,
  _,
  ) = self.tp_worker.get_worker_info()
- if global_server_args_dict["max_micro_batch_size"] is None:
- global_server_args_dict["max_micro_batch_size"] = max(
+ if global_server_args_dict["pp_max_micro_batch_size"] is None:
+ global_server_args_dict["pp_max_micro_batch_size"] = max(
  self.max_running_requests // server_args.pp_size, 1
  )

@@ -401,7 +514,7 @@ class Scheduler(
  f"max_prefill_tokens={self.max_prefill_tokens}, "
  f"max_running_requests={self.max_running_requests}, "
  f"context_len={self.model_config.context_len}, "
- f"available_gpu_mem={avail_mem:.2f} GB"
+ f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
  )

  # Init memory pool and cache
@@ -427,9 +540,11 @@ class Scheduler(
  self.kv_transfer_speed_gb_s: float = 0.0
  self.kv_transfer_latency_ms: float = 0.0
  self.sessions: Dict[str, Session] = {}
- self.current_stream = torch.get_device_module(self.device).current_stream()
+ self.default_stream: CudaStream = torch.get_device_module(
+ self.device
+ ).current_stream()
  if self.device == "cpu":
- self.current_stream.synchronize = lambda: None # No-op for CPU
+ self.default_stream.synchronize = lambda: None # No-op for CPU
  self.forward_sleep_time = None

  # Init chunked prefill
@@ -458,7 +573,12 @@ class Scheduler(
  self.schedule_policy,
  self.tree_cache,
  self.enable_hierarchical_cache,
+ self.enable_priority_scheduling,
+ self.schedule_low_priority_values_first,
  )
+ # Enable preemption for priority scheduling.
+ self.try_preemption = self.enable_priority_scheduling
+
  assert (
  server_args.schedule_conservativeness >= 0
  ), "Invalid schedule_conservativeness"
@@ -488,7 +608,7 @@ class Scheduler(
  enable=server_args.enable_memory_saver
  )
  self.offload_tags = set()
- self.init_profier()
+ self.init_profiler()

  self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
  self.input_blocker = (
@@ -499,8 +619,9 @@ class Scheduler(

  # Init metrics stats
  self.init_metrics(tp_rank, pp_rank, dp_rank)
- self.init_kv_events(server_args.kv_events_config)
- self.init_dp_balance(dp_balance_meta)
+
+ if self.enable_kv_cache_events:
+ self.init_kv_events(server_args.kv_events_config)

  # Init disaggregation
  self.disaggregation_mode = DisaggregationMode(
@@ -511,6 +632,12 @@ class Scheduler(
  if get_bool_env_var("SGLANG_GC_LOG"):
  configure_gc_logger()

+ # Init prefill kv split size when deterministic inference is enabled with various attention backends
+ self.init_deterministic_inference_config()
+
+ # Init overlap
+ self.init_overlap()
+
  # Init request dispatcher
  self._request_dispatcher = TypeBasedDispatcher(
  [
@@ -525,6 +652,15 @@ class Scheduler(
  (CloseSessionReqInput, self.close_session),
  (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
  (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+ (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
+ (
+ InitWeightsSendGroupForRemoteInstanceReqInput,
+ self.init_weights_send_group_for_remote_instance,
+ ),
+ (
+ SendWeightsToRemoteInstanceReqInput,
+ self.send_weights_to_remote_instance,
+ ),
  (
  UpdateWeightsFromDistributedReqInput,
  self.update_weights_from_distributed,
@@ -543,9 +679,27 @@ class Scheduler(
  (LoadLoRAAdapterReqInput, self.load_lora_adapter),
  (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
  (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+ (GetLoadReqInput, self.get_load),
  ]
  )

+ def init_deterministic_inference_config(self):
+ """Initialize deterministic inference configuration for different attention backends."""
+ if not self.server_args.enable_deterministic_inference:
+ self.truncation_align_size = None
+ return
+
+ backend_sizes = {
+ "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
+ "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
+ }
+ env_var, default_size = backend_sizes.get(
+ self.server_args.attention_backend, (None, None)
+ )
+ self.truncation_align_size = (
+ get_int_env_var(env_var, default_size) if env_var else None
+ )
+
  def init_tokenizer(self):
  server_args = self.server_args
  self.is_generation = self.model_config.is_generation
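Note (not part of the diff): init_deterministic_inference_config above resolves a per-backend prefill split size from an environment variable with a default of 4096. The same lookup can be expressed with only the standard library; get_int_env_var is sglang's own helper, so plain os.environ stands in for it in this sketch:

    import os

    # Attention backend -> (override env var, default split/align size), as in the diff.
    BACKEND_SPLIT_SIZES = {
        "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
        "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
    }

    def resolve_truncation_align_size(attention_backend: str):
        env_var, default = BACKEND_SPLIT_SIZES.get(attention_backend, (None, None))
        if env_var is None:
            return None  # backend has no deterministic prefill split setting
        return int(os.environ.get(env_var, default))

    print(resolve_truncation_align_size("triton"))  # 4096 unless the env var overrides it
    print(resolve_truncation_align_size("other"))   # None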
@@ -617,15 +771,18 @@ class Scheduler(
  else self.tp_cpu_group
  ),
  page_size=self.page_size,
+ eviction_policy=server_args.radix_eviction_policy,
  hicache_ratio=server_args.hicache_ratio,
  hicache_size=server_args.hicache_size,
  hicache_write_policy=server_args.hicache_write_policy,
  hicache_io_backend=server_args.hicache_io_backend,
  hicache_mem_layout=server_args.hicache_mem_layout,
+ enable_metrics=self.enable_metrics,
  hicache_storage_backend=server_args.hicache_storage_backend,
  hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
  model_name=server_args.served_model_name,
  storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
+ is_eagle=self.spec_algorithm.is_eagle(),
  )
  self.tp_worker.register_hicache_layer_transfer_counter(
  self.tree_cache.cache_controller.layer_done_counter
@@ -640,19 +797,23 @@ class Scheduler(
  sliding_window_size=self.sliding_window_size,
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
+ is_eagle=self.spec_algorithm.is_eagle(),
  )
- elif self.enable_lora:
- assert (
- not self.enable_hierarchical_cache
- ), "LoRA radix cache doesn't support hierarchical cache"
- assert (
- self.schedule_policy == "fcfs"
- ), "LoRA radix cache only supports FCFS policy"
- self.tree_cache = LoRARadixCache(
+ elif server_args.enable_lmcache:
+ from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
+ LMCRadixCache,
+ )
+
+ self.tree_cache = LMCRadixCache(
  req_to_token_pool=self.req_to_token_pool,
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
+ model_config=self.model_config,
+ tp_size=self.tp_size,
+ rank=self.tp_rank,
+ tp_group=self.tp_group,
+ eviction_policy=server_args.radix_eviction_policy,
  )
  else:
  self.tree_cache = RadixCache(
@@ -661,16 +822,36 @@ class Scheduler(
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
  enable_kv_cache_events=self.enable_kv_cache_events,
+ eviction_policy=server_args.radix_eviction_policy,
+ is_eagle=self.spec_algorithm.is_eagle(),
  )

+ if (
+ server_args.disaggregation_mode == "decode"
+ and server_args.disaggregation_decode_enable_offload_kvcache
+ ):
+ self.decode_offload_manager = DecodeKVCacheOffloadManager(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ tp_group=(
+ self.attn_tp_cpu_group
+ if self.server_args.enable_dp_attention
+ else self.tp_cpu_group
+ ),
+ tree_cache=self.tree_cache,
+ server_args=self.server_args,
+ )
+ else:
+ self.decode_offload_manager = None
+
  self.decode_mem_cache_buf_multiplier = (
  1
  if self.spec_algorithm.is_none()
  else (
  server_args.speculative_num_draft_tokens
  + (
- server_args.speculative_eagle_topk
- * server_args.speculative_num_steps
+ (server_args.speculative_eagle_topk or 1)
+ * (server_args.speculative_num_steps or 1)
  )
  )
  )
@@ -693,7 +874,7 @@ class Scheduler(
  self.disagg_metadata_buffers = MetadataBuffers(
  buffer_size,
  hidden_size=self.model_config.hf_text_config.hidden_size,
- dtype=self.model_config.dtype,
+ hidden_states_dtype=self.model_config.dtype,
  custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
  )

@@ -713,7 +894,7 @@ class Scheduler(
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
  draft_token_to_kv_pool=(
  None
- if self.draft_worker is None
+ if self.draft_worker is None or self.spec_algorithm.is_ngram()
  else self.draft_worker.model_runner.token_to_kv_pool
  ),
  req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -742,7 +923,7 @@ class Scheduler(
  self.disagg_metadata_buffers = MetadataBuffers(
  buffer_size,
  hidden_size=self.model_config.hf_text_config.hidden_size,
- dtype=self.model_config.dtype,
+ hidden_states_dtype=self.model_config.dtype,
  custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
  )

@@ -750,7 +931,7 @@ class Scheduler(
  token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
  draft_token_to_kv_pool=(
  None
- if self.draft_worker is None
+ if self.draft_worker is None or self.spec_algorithm.is_ngram()
  else self.draft_worker.model_runner.token_to_kv_pool
  ),
  req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -771,6 +952,32 @@ class Scheduler(
  # The prefill requests that are in the middle of kv sending
  self.disagg_prefill_inflight_queue: List[Req] = []

+ def init_overlap(self):
+ if not self.enable_overlap:
+ return
+
+ self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
+ self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
+ self.device
+ ).stream(self.forward_stream)
+ self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
+ self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
+ self.device
+ ).stream(self.copy_stream)
+
+ self.future_map = FutureMap(self.max_running_requests, self.device)
+ self.batch_record_buf = [None] * 2
+ self.batch_record_ct = 0
+
+ def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
+ # FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
+ # NOTE: More Reliable: record all tensors into the forward stream
+ # NOTE: - for all future tensors, we shall always read from future map
+ # - for all non-future tensors (produced only by schedule stream),
+ # we shall keep its reference not being release during all the forwarding pass
+ self.batch_record_ct = (self.batch_record_ct + 1) % 2
+ self.batch_record_buf[self.batch_record_ct] = model_worker_batch
+
  def init_moe_config(self):
  if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
  initialize_moe_config(self.server_args)
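Note (not part of the diff): init_overlap above sets up dedicated forward and copy streams plus a two-slot batch_record_buf that keeps the previous batch alive while asynchronous GPU work may still reference its tensors. A rough, generic sketch of the same two-stream pattern (illustrative only; the real scheduling logic is far more involved):

    import torch

    if torch.cuda.is_available():
        forward_stream = torch.cuda.Stream()
        copy_stream = torch.cuda.Stream()

        x = torch.randn(1024, 1024, device="cuda")
        w = torch.randn(1024, 1024, device="cuda")

        # Run compute on a dedicated stream so the default (schedule) stream stays free.
        with torch.cuda.stream(forward_stream):
            y = x @ w

        # Order the device-to-host copy after the matmul, then run it on its own stream.
        copy_stream.wait_stream(forward_stream)
        with torch.cuda.stream(copy_stream):
            y_cpu = y.to("cpu", non_blocking=True)

        # Keep a reference to y until the copy is done (the role of batch_record_buf),
        # then block only when the host actually needs the result.
        copy_stream.synchronize()
        print(y_cpu.shape)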
@@ -797,9 +1004,11 @@ class Scheduler(
  @DynamicGradMode()
  def event_loop_overlap(self):
  """A scheduler loop that overlaps the CPU processing and GPU computation."""
- self.result_queue = deque()
+ self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

  while True:
+ self.launch_last_batch_sample_if_needed()
+
  recv_reqs = self.recv_requests()
  self.process_input_requests(recv_reqs)

@@ -807,30 +1016,13 @@ class Scheduler(
  self.cur_batch = batch

  if batch:
- batch.launch_done = threading.Event()
  result = self.run_batch(batch)
  self.result_queue.append((batch.copy(), result))

- if self.last_batch is None:
- # Create a dummy first batch to start the pipeline for overlap schedule.
- # It is now used for triggering the sampling_info_done event.
- tmp_batch = ScheduleBatch(
- reqs=None,
- forward_mode=ForwardMode.DUMMY_FIRST,
- next_batch_sampling_info=self.tp_worker.cur_sampling_info,
- )
- self.process_batch_result(tmp_batch, None, batch.launch_done)
-
  if self.last_batch:
  # Process the results of the last batch
  tmp_batch, tmp_result = self.result_queue.popleft()
- tmp_batch.next_batch_sampling_info = (
- self.tp_worker.cur_sampling_info if batch else None
- )
- # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
- self.process_batch_result(
- tmp_batch, tmp_result, batch.launch_done if batch else None
- )
+ self.process_batch_result(tmp_batch, tmp_result)
  elif batch is None:
  # When the server is idle, do self-check and re-init some states
  self.self_check_during_idle()
@@ -845,7 +1037,6 @@ class Scheduler(
  self.running_mbs = [
  ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
  ]
- bids = [None] * self.pp_size
  pp_outputs: Optional[PPProxyTensors] = None
  while True:
  server_is_idle = True
@@ -866,10 +1057,7 @@ class Scheduler(
  # (last rank) send the outputs to the next step
  if self.pp_group.is_last_rank:
  if self.cur_batch:
- next_token_ids, bids[mb_id] = (
- result.next_token_ids,
- result.bid,
- )
+ next_token_ids = result.next_token_ids
  if self.cur_batch.return_logprob:
  pp_outputs = PPProxyTensors(
  {
@@ -917,17 +1105,10 @@ class Scheduler(
  logits_output = LogitsProcessorOutput(**logits_output_args)
  else:
  logits_output = None
- output_result = GenerationBatchResult(
+
+ output_result = GenerationBatchResult.from_pp_proxy(
  logits_output=logits_output,
- pp_hidden_states_proxy_tensors=None,
- next_token_ids=next_pp_outputs["next_token_ids"],
- extend_input_len_per_req=next_pp_outputs.tensors.get(
- "extend_input_len_per_req", None
- ),
- extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
- "extend_logprob_start_len_per_req", None
- ),
- bid=bids[next_mb_id],
+ next_pp_outputs=next_pp_outputs,
  can_run_cuda_graph=result.can_run_cuda_graph,
  )
  self.process_batch_result(mbs[next_mb_id], output_result)
@@ -935,8 +1116,6 @@ class Scheduler(

  # (not last rank)
  if not self.pp_group.is_last_rank:
- if self.cur_batch:
- bids[mb_id] = result.bid
  # carry the outputs to the next stage
  # send the outputs from the last round to let the next stage worker run post processing
  if pp_outputs:
@@ -958,8 +1137,10 @@ class Scheduler(

  # send out proxy tensors to the next stage
  if self.cur_batch:
+ # FIXME(lsyin): remove this assert
+ assert result.pp_hidden_states_proxy_tensors.tensors is not None
  self.pp_group.send_tensor_dict(
- result.pp_hidden_states_proxy_tensors,
+ result.pp_hidden_states_proxy_tensors.tensors,
  all_gather_group=self.attn_tp_group,
  )

@@ -1069,6 +1250,15 @@ class Scheduler(
  self.tp_cpu_group,
  src=self.tp_group.ranks[0],
  )
+
+ if self.enable_trace:
+ for req in recv_reqs:
+ if isinstance(
+ req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+ ):
+ trace_set_proc_propagate_context(req.rid, req.trace_context)
+ trace_slice_start("", req.rid, anonymous=True)
+
  return recv_reqs

  def process_input_requests(self, recv_reqs: List):
@@ -1082,27 +1272,13 @@ class Scheduler(
  self.return_health_check_ct += 1
  continue

- # If it is a work request, accept or reject the request based on the request queue size.
- if is_work_request(recv_req):
- if len(self.waiting_queue) + 1 > self.max_queued_requests:
- abort_req = AbortReq(
- recv_req.rid,
- finished_reason={
- "type": "abort",
- "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
- "message": "The request queue is full.",
- },
- )
- self.send_to_tokenizer.send_pyobj(abort_req)
- continue
-
- # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
- if isinstance(recv_req, MultiTokenizerWarpper):
+ # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
+ if isinstance(recv_req, MultiTokenizerWrapper):
  worker_id = recv_req.worker_id
  recv_req = recv_req.obj
  output = self._request_dispatcher(recv_req)
  if output is not None:
- output = MultiTokenizerWarpper(worker_id, output)
+ output = MultiTokenizerWrapper(worker_id, output)
  self.send_to_tokenizer.send_pyobj(output)
  continue

@@ -1114,12 +1290,20 @@ class Scheduler(
  else:
  self.send_to_tokenizer.send_pyobj(output)

+ def init_req_max_new_tokens(self, req):
+ req.sampling_params.max_new_tokens = min(
+ (
+ req.sampling_params.max_new_tokens
+ if req.sampling_params.max_new_tokens is not None
+ else 1 << 30
+ ),
+ self.max_req_len - len(req.origin_input_ids) - 1,
+ )
+
  def handle_generate_request(
  self,
  recv_req: TokenizedGenerateReqInput,
  ):
- self.maybe_update_dp_balance_data(recv_req)
-
  # Create a new request
  if (
  recv_req.session_params is None
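Note (not part of the diff): the new init_req_max_new_tokens helper clamps the requested output budget so that prompt plus generated tokens never exceed the per-request length limit. The arithmetic in isolation (numbers purely illustrative):

    def clamp_max_new_tokens(requested, prompt_len, max_req_len):
        # None means "no explicit limit", treated as a very large value (1 << 30),
        # then clamped to the remaining room: max_req_len - prompt_len - 1.
        return min(requested if requested is not None else 1 << 30,
                   max_req_len - prompt_len - 1)

    print(clamp_max_new_tokens(None, prompt_len=1000, max_req_len=4096))  # 3095
    print(clamp_max_new_tokens(128, prompt_len=1000, max_req_len=4096))   # 128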
@@ -1153,8 +1337,13 @@ class Scheduler(
1153
1337
  bootstrap_host=recv_req.bootstrap_host,
1154
1338
  bootstrap_port=recv_req.bootstrap_port,
1155
1339
  bootstrap_room=recv_req.bootstrap_room,
1340
+ disagg_mode=self.disaggregation_mode,
1156
1341
  data_parallel_rank=recv_req.data_parallel_rank,
1157
1342
  vocab_size=self.model_config.vocab_size,
1343
+ priority=recv_req.priority,
1344
+ metrics_collector=(
1345
+ self.metrics_collector if self.enable_metrics else None
1346
+ ),
1158
1347
  )
1159
1348
  req.tokenizer = self.tokenizer
1160
1349
 
@@ -1177,6 +1366,7 @@ class Scheduler(
1177
1366
  req.set_finish_with_abort(
1178
1367
  f"Invalid request: session id {recv_req.session_params.id} does not exist"
1179
1368
  )
1369
+ self.init_req_max_new_tokens(req)
1180
1370
  self._add_request_to_queue(req)
1181
1371
  return
1182
1372
  else:
@@ -1184,6 +1374,7 @@ class Scheduler(
1184
1374
  session = self.sessions[recv_req.session_params.id]
1185
1375
  req = session.create_req(recv_req, self.tokenizer)
1186
1376
  if isinstance(req.finished_reason, FINISH_ABORT):
1377
+ self.init_req_max_new_tokens(req)
1187
1378
  self._add_request_to_queue(req)
1188
1379
  return
1189
1380
 
@@ -1203,9 +1394,13 @@ class Scheduler(
1203
1394
  f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
1204
1395
  )
1205
1396
  )
1397
+ self.init_req_max_new_tokens(req)
1206
1398
  self._add_request_to_queue(req)
1207
1399
  return
1208
1400
 
1401
+ # Initialize max_new_tokens before returning
1402
+ self.init_req_max_new_tokens(req)
1403
+
1209
1404
  # Validate prompt length
1210
1405
  error_msg = validate_input_length(
1211
1406
  req,
@@ -1220,26 +1415,25 @@ class Scheduler(
1220
1415
  # Copy more attributes
1221
1416
  if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
1222
1417
  # By default, only return the logprobs for output tokens
1223
- req.logprob_start_len = len(req.origin_input_ids) - 1
1418
+ # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
1419
+ # to skip input logprob computation entirely
1420
+ if req.is_prefill_only:
1421
+ req.logprob_start_len = len(req.origin_input_ids)
1422
+ else:
1423
+ # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
1424
+ req.logprob_start_len = len(req.origin_input_ids) - 1
1224
1425
  else:
1225
1426
  req.logprob_start_len = recv_req.logprob_start_len
1226
1427
 
1227
- if req.logprob_start_len >= len(req.origin_input_ids):
1428
+ if not req.is_prefill_only and req.logprob_start_len >= len(
1429
+ req.origin_input_ids
1430
+ ):
1228
1431
  error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
1229
1432
  req.logprob_start_len = len(req.origin_input_ids) - 1
1230
1433
  req.set_finish_with_abort(error_msg)
1231
1434
  self._add_request_to_queue(req)
1232
1435
  return
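Reading aid for the logprob change above: when the caller does not request input logprobs, a prefill-only request now points `logprob_start_len` past the whole prompt so input logprob computation is skipped, while generation requests keep the old `len(input) - 1` default. A small hedged sketch of that branching with made-up names:

```python
# Hedged sketch of the defaulting rule; names are illustrative, not the
# Scheduler's actual attributes.
def default_logprob_start_len(num_input_tokens: int, is_prefill_only: bool,
                              requested: int = -1,
                              return_logprob: bool = False) -> int:
    if requested == -1 or not return_logprob:
        if is_prefill_only:
            # Skip input logprob computation entirely.
            return num_input_tokens
        # Only return logprobs for output tokens.
        return num_input_tokens - 1
    return requested

assert default_logprob_start_len(8, is_prefill_only=True) == 8
assert default_logprob_start_len(8, is_prefill_only=False) == 7
assert default_logprob_start_len(8, False, requested=3, return_logprob=True) == 3
```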
1233
1436
 
1234
- req.sampling_params.max_new_tokens = min(
1235
- (
1236
- req.sampling_params.max_new_tokens
1237
- if req.sampling_params.max_new_tokens is not None
1238
- else 1 << 30
1239
- ),
1240
- self.max_req_len - len(req.origin_input_ids) - 1,
1241
- )
1242
-
1243
1437
  # Init grammar cache for this request
1244
1438
  add_to_grammar_queue = False
1245
1439
  if (
@@ -1270,7 +1464,6 @@ class Scheduler(
1270
1464
  req.set_finish_with_abort(error_msg)
1271
1465
 
1272
1466
  if add_to_grammar_queue:
1273
- req.queue_time_start = time.perf_counter()
1274
1467
  self.grammar_queue.append(req)
1275
1468
  else:
1276
1469
  self._add_request_to_queue(req)
@@ -1286,19 +1479,6 @@ class Scheduler(
1286
1479
  for tokenized_req in recv_req:
1287
1480
  self.handle_generate_request(tokenized_req)
1288
1481
 
1289
- def _add_request_to_queue(self, req: Req):
1290
- req.queue_time_start = time.perf_counter()
1291
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
1292
- self._prefetch_kvcache(req)
1293
- self.disagg_prefill_bootstrap_queue.add(
1294
- req, self.model_config.num_key_value_heads
1295
- )
1296
- elif self.disaggregation_mode == DisaggregationMode.DECODE:
1297
- self.disagg_decode_prealloc_queue.add(req)
1298
- else:
1299
- self._prefetch_kvcache(req)
1300
- self.waiting_queue.append(req)
1301
-
1302
1482
  def _prefetch_kvcache(self, req: Req):
1303
1483
  if self.enable_hicache_storage:
1304
1484
  req.init_next_round_input(self.tree_cache)
@@ -1312,16 +1492,87 @@ class Scheduler(
1312
1492
  req.rid, req.last_host_node, new_input_tokens, last_hash
1313
1493
  )
1314
1494
 
1315
- def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
1316
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
1317
- self.disagg_prefill_bootstrap_queue.extend(
1318
- reqs, self.model_config.num_key_value_heads
1495
+ def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
1496
+ if self.disaggregation_mode == DisaggregationMode.NULL:
1497
+ self._set_or_validate_priority(req)
1498
+ if self._abort_on_queued_limit(req):
1499
+ return
1500
+ self._prefetch_kvcache(req)
1501
+ self.waiting_queue.append(req)
1502
+ req.time_stats.wait_queue_entry_time = time.perf_counter()
1503
+ trace_slice_end("process req", req.rid, auto_next_anon=True)
1504
+ elif self.disaggregation_mode == DisaggregationMode.PREFILL:
1505
+ self._prefetch_kvcache(req)
1506
+ self.disagg_prefill_bootstrap_queue.add(
1507
+ req, self.model_config.num_key_value_heads
1319
1508
  )
1509
+ req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
1320
1510
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
1321
- # If this is a decode server, we put the request to the decode pending prealloc queue
1322
- self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
1511
+ self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
1512
+ if not is_retracted:
1513
+ req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
1323
1514
  else:
1324
- self.waiting_queue.extend(reqs)
1515
+ raise ValueError(f"Invalid {self.disaggregation_mode=}")
1516
+
1517
+ def _set_or_validate_priority(self, req: Req):
1518
+ """Set the default priority value, or abort the request based on the priority scheduling mode."""
1519
+ if self.enable_priority_scheduling and req.priority is None:
1520
+ if self.schedule_low_priority_values_first:
1521
+ req.priority = sys.maxsize
1522
+ else:
1523
+ req.priority = -sys.maxsize - 1
1524
+ elif not self.enable_priority_scheduling and req.priority is not None:
1525
+ abort_req = AbortReq(
1526
+ finished_reason={
1527
+ "type": "abort",
1528
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1529
+ "message": "Using priority is disabled for this server. Please send a new request without a priority.",
1530
+ },
1531
+ rid=req.rid,
1532
+ )
1533
+ self.send_to_tokenizer.send_pyobj(abort_req)
1534
+
1535
+ def _abort_on_queued_limit(self, recv_req: Req) -> bool:
1536
+ """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
1537
+ if (
1538
+ self.max_queued_requests is None
1539
+ or len(self.waiting_queue) + 1 <= self.max_queued_requests
1540
+ ):
1541
+ return False
1542
+
1543
+ # Reject the incoming request by default.
1544
+ req_to_abort = recv_req
1545
+ message = "The request queue is full."
1546
+ if self.enable_priority_scheduling:
1547
+ # With priority scheduling, consider aborting an existing request based on the priority.
1548
+ # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
1549
+ # max(...) over the key (direction * priority, queue_time_start) picks the least-preferred request.
1550
+ # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
1551
+ direction = 1 if self.schedule_low_priority_values_first else -1
1552
+ key_fn = lambda item: (
1553
+ direction * item[1].priority,
1554
+ item[1].time_stats.wait_queue_entry_time,
1555
+ )
1556
+ idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
1557
+ abort_existing_req = (
1558
+ direction * recv_req.priority < direction * candidate_req.priority
1559
+ )
1560
+ if abort_existing_req:
1561
+ self.waiting_queue.pop(idx)
1562
+ req_to_abort = candidate_req
1563
+ message = "The request is aborted by a higher priority request."
1564
+
1565
+ self.send_to_tokenizer.send_pyobj(
1566
+ AbortReq(
1567
+ finished_reason={
1568
+ "type": "abort",
1569
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1570
+ "message": message,
1571
+ },
1572
+ rid=req_to_abort.rid,
1573
+ )
1574
+ )
1575
+ return req_to_abort.rid == recv_req.rid
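For orientation, `_abort_on_queued_limit` implements a simple admission policy: if the waiting queue is full, either reject the newcomer or, under priority scheduling, evict the least-preferred queued request when the newcomer is strictly better. A self-contained, hedged sketch of the selection rule only (the `QueuedReq` type here is a toy, not sglang's `Req`):

```python
from dataclasses import dataclass

@dataclass
class QueuedReq:            # illustrative stand-in for sglang's Req
    rid: str
    priority: int
    entry_time: float

def pick_request_to_abort(queue, incoming, low_values_first=True):
    """Return the request to abort when the queue is over capacity."""
    direction = 1 if low_values_first else -1
    # Largest (direction * priority, entry_time) = least-preferred request;
    # on ties, the request that entered the queue most recently is evicted.
    idx, worst = max(enumerate(queue),
                     key=lambda it: (direction * it[1].priority, it[1].entry_time))
    if direction * incoming.priority < direction * worst.priority:
        queue.pop(idx)      # evict an existing request, admit the newcomer
        return worst
    return incoming         # otherwise reject the newcomer itself

q = [QueuedReq("a", 5, 1.0), QueuedReq("b", 9, 2.0)]
print(pick_request_to_abort(q, QueuedReq("c", 1, 3.0)).rid)  # "b"
print(pick_request_to_abort(q, QueuedReq("d", 9, 4.0)).rid)  # "d"
```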
1325
1576
 
1326
1577
  def handle_embedding_request(
1327
1578
  self,
@@ -1333,6 +1584,7 @@ class Scheduler(
1333
1584
  recv_req.input_ids,
1334
1585
  recv_req.sampling_params,
1335
1586
  token_type_ids=recv_req.token_type_ids,
1587
+ priority=recv_req.priority,
1336
1588
  )
1337
1589
  req.tokenizer = self.tokenizer
1338
1590
 
@@ -1409,9 +1661,11 @@ class Scheduler(
1409
1661
  _, _, available_size, evictable_size = self._get_token_info()
1410
1662
  protected_size = self.tree_cache.protected_size()
1411
1663
  memory_leak = (available_size + evictable_size) != (
1664
+ # self.max_total_num_tokens
1665
+ # if not self.enable_hierarchical_cache
1666
+ # else self.max_total_num_tokens - protected_size
1412
1667
  self.max_total_num_tokens
1413
- if not self.enable_hierarchical_cache
1414
- else self.max_total_num_tokens - protected_size
1668
+ - protected_size
1415
1669
  )
1416
1670
  token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
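The check above drops the hierarchical-cache special case and always subtracts the protected size. Conceptually, free slots plus evictable slots must equal the total pool minus slots protected from eviction; anything else is treated as leaked KV-cache tokens. A tiny hedged illustration with made-up numbers:

```python
# Hedged illustration of the accounting invariant; all numbers are made up.
def kv_cache_leaked(max_total, available, evictable, protected):
    return (available + evictable) != (max_total - protected)

# 10_000-token pool: 6_000 free, 3_500 evictable, 500 protected -> balanced.
print(kv_cache_leaked(10_000, 6_000, 3_500, 500))   # False
# Same pool with 100 tokens unaccounted for -> flagged as a leak.
print(kv_cache_leaked(10_000, 5_900, 3_500, 500))   # True
```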
1417
1671
 
@@ -1462,6 +1716,20 @@ class Scheduler(
1462
1716
  self.stats.gen_throughput = 0
1463
1717
  self.stats.num_queue_reqs = len(self.waiting_queue)
1464
1718
  self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
1719
+ if self.disaggregation_mode == DisaggregationMode.PREFILL:
1720
+ self.stats.num_prefill_prealloc_queue_reqs = len(
1721
+ self.disagg_prefill_bootstrap_queue.queue
1722
+ )
1723
+ self.stats.num_prefill_inflight_queue_reqs = len(
1724
+ self.disagg_prefill_inflight_queue
1725
+ )
1726
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
1727
+ self.stats.num_decode_prealloc_queue_reqs = len(
1728
+ self.disagg_decode_prealloc_queue.queue
1729
+ )
1730
+ self.stats.num_decode_transfer_queue_reqs = len(
1731
+ self.disagg_decode_transfer_queue.queue
1732
+ )
1465
1733
  self.metrics_collector.log_stats(self.stats)
1466
1734
  self._publish_kv_events()
1467
1735
 
@@ -1509,7 +1777,12 @@ class Scheduler(
1509
1777
  chunked_req_to_exclude.add(self.chunked_req)
1510
1778
  self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
1511
1779
  # chunked request keeps its rid but will get a new req_pool_idx
1512
- self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
1780
+ if self.tp_worker.worker.model_runner.mambaish_config is not None:
1781
+ self.req_to_token_pool.free(
1782
+ self.chunked_req.req_pool_idx, free_mamba_cache=False
1783
+ )
1784
+ else:
1785
+ self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
1513
1786
  if self.last_batch and self.last_batch.forward_mode.is_extend():
1514
1787
  if self.last_batch.chunked_req is not None:
1515
1788
  # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks the outdated chunked_req.
@@ -1556,13 +1829,12 @@ class Scheduler(
1556
1829
 
1557
1830
  # Handle DP attention
1558
1831
  if need_dp_attn_preparation:
1559
- self.maybe_handle_dp_balance_data()
1560
1832
  ret = self.prepare_mlp_sync_batch(ret)
1561
1833
 
1562
1834
  return ret
1563
1835
 
1564
1836
  def get_num_allocatable_reqs(self, running_bs):
1565
- res = global_server_args_dict["max_micro_batch_size"] - running_bs
1837
+ res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs
1566
1838
  if self.pp_size > 1:
1567
1839
  res = min(res, self.req_to_token_pool.available_size())
1568
1840
  return res
@@ -1572,6 +1844,10 @@ class Scheduler(
1572
1844
  if self.grammar_queue:
1573
1845
  self.move_ready_grammar_requests()
1574
1846
 
1847
+ if self.try_preemption:
1848
+ # Reset batch_is_full to try preemption with a prefill adder.
1849
+ self.running_batch.batch_is_full = False
1850
+
1575
1851
  # Handle the cases where prefill is not allowed
1576
1852
  if (
1577
1853
  self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1584,7 +1860,11 @@ class Scheduler(
1584
1860
  # as the space for the chunked request has just been released.
1585
1861
  # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
1586
1862
  # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
1587
- if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
1863
+ if (
1864
+ self.get_num_allocatable_reqs(running_bs) <= 0
1865
+ and not self.chunked_req
1866
+ and not self.try_preemption
1867
+ ):
1588
1868
  self.running_batch.batch_is_full = True
1589
1869
  return None
1590
1870
 
@@ -1604,6 +1884,7 @@ class Scheduler(
1604
1884
  self.max_prefill_tokens,
1605
1885
  self.chunked_prefill_size,
1606
1886
  running_bs if self.is_mixed_chunk else 0,
1887
+ self.priority_scheduling_preemption_threshold,
1607
1888
  )
1608
1889
 
1609
1890
  if self.chunked_req is not None:
@@ -1624,15 +1905,19 @@ class Scheduler(
1624
1905
  self.running_batch.batch_is_full = True
1625
1906
  break
1626
1907
 
1908
+ running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
1627
1909
  if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
1628
1910
  self.running_batch.batch_is_full = True
1629
- break
1630
-
1631
1911
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
1632
1912
  # In prefill mode, prealloc queue and transfer queue can also take memory,
1633
1913
  # so we need to check if the available size for the actual available size.
1634
1914
  if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
1635
1915
  self.running_batch.batch_is_full = True
1916
+
1917
+ if self.running_batch.batch_is_full:
1918
+ if not self.try_preemption:
1919
+ break
1920
+ if not adder.preempt_to_schedule(req, self.server_args):
1636
1921
  break
1637
1922
 
1638
1923
  if self.enable_hicache_storage:
@@ -1642,7 +1927,11 @@ class Scheduler(
1642
1927
  continue
1643
1928
 
1644
1929
  req.init_next_round_input(self.tree_cache)
1645
- res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
1930
+ res = adder.add_one_req(
1931
+ req,
1932
+ has_chunked_req=(self.chunked_req is not None),
1933
+ truncation_align_size=self.truncation_align_size,
1934
+ )
1646
1935
 
1647
1936
  if res != AddReqResult.CONTINUE:
1648
1937
  if res == AddReqResult.NO_TOKEN:
@@ -1663,11 +1952,14 @@ class Scheduler(
1663
1952
  if self.enable_metrics:
1664
1953
  # only record queue time when enable_metrics is True to avoid overhead
1665
1954
  for req in can_run_list:
1666
- req.queue_time_end = time.perf_counter()
1955
+ req.add_latency(RequestStage.PREFILL_WAITING)
1667
1956
 
1668
1957
  self.waiting_queue = [
1669
1958
  x for x in self.waiting_queue if x not in set(can_run_list)
1670
1959
  ]
1960
+ if adder.preempt_list:
1961
+ for req in adder.preempt_list:
1962
+ self._add_request_to_queue(req)
1671
1963
 
1672
1964
  if adder.new_chunked_req is not None:
1673
1965
  assert self.chunked_req is None
@@ -1678,7 +1970,16 @@ class Scheduler(
1678
1970
 
1679
1971
  # Print stats
1680
1972
  if self.current_scheduler_metrics_enabled():
1681
- self.log_prefill_stats(adder, can_run_list, running_bs)
1973
+ self.log_prefill_stats(adder, can_run_list, running_bs, 0)
1974
+
1975
+ for req in can_run_list:
1976
+ if req.time_stats.forward_entry_time == 0:
1977
+ # Avoid update chunked request many times
1978
+ req.time_stats.forward_entry_time = time.perf_counter()
1979
+ if self.enable_metrics:
1980
+ self.metrics_collector.observe_queue_time(
1981
+ req.time_stats.get_queueing_time(),
1982
+ )
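The metrics change above replaces the old `queue_time_start`/`queue_time_end` pair with per-stage timestamps on `req.time_stats`; queueing time is derived when the request first enters a forward batch. A tiny hedged sketch of such a time-stats record (field names are illustrative):

```python
# Hedged sketch of per-request timing; field names are illustrative.
import time
from dataclasses import dataclass

@dataclass
class ToyTimeStats:
    wait_queue_entry_time: float = 0.0
    forward_entry_time: float = 0.0

    def get_queueing_time(self) -> float:
        return self.forward_entry_time - self.wait_queue_entry_time

stats = ToyTimeStats()
stats.wait_queue_entry_time = time.perf_counter()
time.sleep(0.01)                       # request waits in the queue
if stats.forward_entry_time == 0:      # set only once, even for chunked prefill
    stats.forward_entry_time = time.perf_counter()
print(f"queued for {stats.get_queueing_time() * 1e3:.1f} ms")
```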
1682
1983
 
1683
1984
  # Create a new batch
1684
1985
  new_batch = ScheduleBatch.init_new(
@@ -1733,19 +2034,25 @@ class Scheduler(
1733
2034
  TEST_RETRACT and batch.batch_size() > 10
1734
2035
  ):
1735
2036
  old_ratio = self.new_token_ratio
1736
-
1737
- retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
1738
- num_retracted_reqs = len(retracted_reqs)
2037
+ retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
2038
+ self.server_args
2039
+ )
2040
+ self.num_retracted_reqs = len(retracted_reqs)
1739
2041
  self.new_token_ratio = new_token_ratio
2042
+ for req in reqs_to_abort:
2043
+ self.send_to_tokenizer.send_pyobj(
2044
+ AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
2045
+ )
1740
2046
 
1741
2047
  logger.info(
1742
2048
  "KV cache pool is full. Retract requests. "
1743
- f"#retracted_reqs: {num_retracted_reqs}, "
1744
- f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
2049
+ f"#retracted_reqs: {len(retracted_reqs)}, "
2050
+ f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
2051
+ f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
1745
2052
  )
1746
2053
 
1747
- self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
1748
- self.total_retracted_reqs += num_retracted_reqs
2054
+ for req in retracted_reqs:
2055
+ self._add_request_to_queue(req, is_retracted=True)
1749
2056
  else:
1750
2057
  self.new_token_ratio = max(
1751
2058
  self.new_token_ratio - self.new_token_ratio_decay,
@@ -1773,37 +2080,66 @@ class Scheduler(
1773
2080
 
1774
2081
  # Run forward
1775
2082
  if self.is_generation:
2083
+
2084
+ batch_or_worker_batch = batch
2085
+
1776
2086
  if self.spec_algorithm.is_none():
1777
- model_worker_batch = batch.get_model_worker_batch()
2087
+ # FIXME(lsyin): remove this if and finally unify the abstraction
2088
+ batch_or_worker_batch = batch.get_model_worker_batch()
1778
2089
 
1779
- # update the consumer index of hicache to the running batch
1780
- self.tp_worker.set_hicache_consumer(
1781
- model_worker_batch.hicache_consumer_index
2090
+ if self.enable_overlap:
2091
+ # FIXME: remove this assert
2092
+ assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
2093
+ model_worker_batch = batch_or_worker_batch
2094
+ self.record_batch_in_overlap(model_worker_batch)
2095
+
2096
+ # Sampling info will be modified during forward
2097
+ model_worker_batch.sampling_info = (
2098
+ model_worker_batch.sampling_info.copy_for_forward()
1782
2099
  )
1783
- if self.pp_group.is_last_rank:
1784
- logits_output, next_token_ids, can_run_cuda_graph = (
1785
- self.tp_worker.forward_batch_generation(model_worker_batch)
1786
- )
1787
- else:
1788
- pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
1789
- self.tp_worker.forward_batch_generation(model_worker_batch)
2100
+
2101
+ bs = len(model_worker_batch.seq_lens)
2102
+ future_indices = self.future_map.alloc_future_indices(bs)
2103
+
2104
+ with self.forward_stream_ctx:
2105
+ self.forward_stream.wait_stream(self.default_stream)
2106
+ self.future_map.resolve_future(model_worker_batch)
2107
+ if batch.sampling_info.grammars is not None:
2108
+ model_worker_batch.delay_sample_launch = True
2109
+ batch_result = self.model_worker.forward_batch_generation(
2110
+ batch_or_worker_batch
1790
2111
  )
1791
- bid = model_worker_batch.bid
2112
+ # FIXME(lsyin): maybe move this to forward_batch_generation
2113
+ batch_result.copy_done = torch.get_device_module(
2114
+ self.device
2115
+ ).Event()
2116
+ if not model_worker_batch.delay_sample_launch:
2117
+ self.future_map.store_to_map(
2118
+ future_indices, batch_result.next_token_ids
2119
+ )
2120
+ batch_result.copy_to_cpu()
2121
+ else:
2122
+ batch_result.future_indices = future_indices
2123
+
2124
+ # FIXME(lsyin): move this assignment elsewhere
2125
+ maybe_future_next_token_ids = -future_indices.indices
1792
2126
  else:
1793
- (
1794
- logits_output,
1795
- next_token_ids,
1796
- bid,
1797
- num_accepted_tokens,
1798
- can_run_cuda_graph,
1799
- ) = self.draft_worker.forward_batch_speculative_generation(batch)
1800
- bs = batch.batch_size()
1801
- self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
1802
- self.spec_num_total_forward_ct += bs
1803
- self.num_generated_tokens += num_accepted_tokens
1804
-
1805
- if self.pp_group.is_last_rank:
1806
- batch.output_ids = next_token_ids
2127
+ batch_result = self.model_worker.forward_batch_generation(
2128
+ batch_or_worker_batch
2129
+ )
2130
+ maybe_future_next_token_ids = batch_result.next_token_ids
2131
+
2132
+ if not self.spec_algorithm.is_none():
2133
+ # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
2134
+ self.update_spec_metrics(
2135
+ batch.batch_size(), batch_result.num_accepted_tokens
2136
+ )
2137
+
2138
+ # NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
2139
+ # which can probably be replaced by future_indices later [TODO(lsyin)].
2140
+ # we shall still keep the original outputs, e.g. next_token_ids
2141
+ # in the GenerationBatchOutput for processing after copy_done.
2142
+ batch.output_ids = maybe_future_next_token_ids
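The overlap path above replaces concrete next-token ids with placeholders: the forward is launched on a side stream, a slot in the future map is reserved for each request, and `batch.output_ids` temporarily holds the negated slot indices until the copy event completes and the map is resolved. A minimal, hedged sketch of that placeholder scheme (plain Python lists instead of the real future map and CUDA streams):

```python
# Hedged sketch of a future-token-id map; no CUDA streams involved.
class ToyFutureMap:
    def __init__(self, size):
        self.buf = [None] * size
        self.head = 1  # index 0 is reserved so placeholders are never 0

    def alloc(self, bs):
        idx = list(range(self.head, self.head + bs))
        self.head += bs
        return idx

    def store(self, indices, token_ids):
        for i, tok in zip(indices, token_ids):
            self.buf[i] = tok

    def resolve(self, placeholders):
        # Placeholders are negative indices; real token ids pass through.
        return [self.buf[-p] if p < 0 else p for p in placeholders]

fm = ToyFutureMap(16)
slots = fm.alloc(2)                 # e.g. [1, 2]
placeholders = [-s for s in slots]  # what the schedule batch sees first
fm.store(slots, [42, 7])            # filled in once the forward finishes
print(fm.resolve(placeholders))     # [42, 7]
```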
1807
2143
 
1808
2144
  # These 2 values are needed for processing the output, but the values can be
1809
2145
  # modified by overlap schedule. So we have to copy them here so that
@@ -1812,6 +2148,7 @@ class Scheduler(
1812
2148
  extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
1813
2149
  else:
1814
2150
  extend_input_len_per_req = None
2151
+
1815
2152
  if batch.return_logprob:
1816
2153
  extend_logprob_start_len_per_req = [
1817
2154
  req.extend_logprob_start_len for req in batch.reqs
@@ -1819,43 +2156,60 @@ class Scheduler(
1819
2156
  else:
1820
2157
  extend_logprob_start_len_per_req = None
1821
2158
 
1822
- ret = GenerationBatchResult(
1823
- logits_output=logits_output if self.pp_group.is_last_rank else None,
1824
- pp_hidden_states_proxy_tensors=(
1825
- pp_hidden_states_proxy_tensors
1826
- if not self.pp_group.is_last_rank
1827
- else None
1828
- ),
1829
- next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
1830
- extend_input_len_per_req=extend_input_len_per_req,
1831
- extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
1832
- bid=bid,
1833
- can_run_cuda_graph=can_run_cuda_graph,
2159
+ batch_result.extend_input_len_per_req = extend_input_len_per_req
2160
+ batch_result.extend_logprob_start_len_per_req = (
2161
+ extend_logprob_start_len_per_req
1834
2162
  )
2163
+ return batch_result
1835
2164
  else: # embedding or reward model
1836
2165
  model_worker_batch = batch.get_model_worker_batch()
1837
2166
  embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
1838
- ret = EmbeddingBatchResult(
1839
- embeddings=embeddings, bid=model_worker_batch.bid
1840
- )
2167
+ ret = EmbeddingBatchResult(embeddings=embeddings)
1841
2168
  return ret
1842
2169
 
2170
+ def launch_last_batch_sample_if_needed(
2171
+ self,
2172
+ ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
2173
+ if len(self.result_queue) == 0:
2174
+ return
2175
+
2176
+ tmp_batch, tmp_result = self.result_queue.popleft()
2177
+
2178
+ tmp_result: GenerationBatchResult
2179
+ if not tmp_result.delay_sample_launch:
2180
+ self.result_queue.appendleft((tmp_batch, tmp_result))
2181
+ return
2182
+
2183
+ with self.forward_stream_ctx:
2184
+ self.forward_stream.wait_stream(self.default_stream)
2185
+ tmp_result.next_token_ids = self.model_worker.model_runner.sample(
2186
+ tmp_result.logits_output,
2187
+ tmp_result.forward_batch,
2188
+ )
2189
+ future_indices = tmp_result.future_indices
2190
+ self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
2191
+ tmp_result.copy_to_cpu()
2192
+ self.result_queue.appendleft((tmp_batch, tmp_result))
2193
+
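For the grammar-constrained case, `delay_sample_launch` defers sampling: the forward result is queued without token ids, and `launch_last_batch_sample_if_needed` later pops it, samples, and puts it back at the front of the queue. A compact, hedged sketch of the queue discipline only (plain dicts and a stub `sample_fn`, not the real batch result objects):

```python
# Hedged sketch: defer sampling for the newest queued result, if it asked for it.
from collections import deque

def launch_last_batch_sample_if_needed(result_queue, sample_fn):
    if not result_queue:
        return
    batch, result = result_queue.popleft()
    if not result.get("delay_sample_launch"):
        result_queue.appendleft((batch, result))   # nothing to do, put it back
        return
    result["next_token_ids"] = sample_fn(result["logits"])
    result["delay_sample_launch"] = False
    result_queue.appendleft((batch, result))

q = deque([("batch0", {"delay_sample_launch": True, "logits": [0.1, 0.9]})])
launch_last_batch_sample_if_needed(q, sample_fn=lambda logits: [logits.index(max(logits))])
print(q[0][1]["next_token_ids"])   # [1]
```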
1843
2194
  def process_batch_result(
1844
2195
  self,
1845
2196
  batch: ScheduleBatch,
1846
2197
  result: Union[GenerationBatchResult, EmbeddingBatchResult],
1847
- launch_done: Optional[threading.Event] = None,
1848
2198
  ):
1849
2199
  if batch.forward_mode.is_decode():
1850
- self.process_batch_result_decode(batch, result, launch_done)
2200
+ self.process_batch_result_decode(batch, result)
2201
+ if self.enable_trace:
2202
+ trace_slice_batch("decode loop", batch.reqs)
2203
+
1851
2204
  elif batch.forward_mode.is_extend():
1852
- self.process_batch_result_prefill(batch, result, launch_done)
2205
+ self.process_batch_result_prefill(batch, result)
2206
+ if self.enable_trace:
2207
+ trace_slice_batch("prefill", batch.reqs)
2208
+
1853
2209
  elif batch.forward_mode.is_idle():
1854
2210
  if self.enable_overlap:
1855
- self.tp_worker.resolve_last_batch_result(launch_done)
1856
- self.set_next_batch_sampling_info_done(batch)
1857
- elif batch.forward_mode.is_dummy_first():
1858
- self.set_next_batch_sampling_info_done(batch)
2211
+ if result.copy_done is not None:
2212
+ result.copy_done.synchronize()
1859
2213
 
1860
2214
  self.maybe_send_health_check_signal()
1861
2215
 
@@ -2008,12 +2362,13 @@ class Scheduler(
2008
2362
  if req.finished(): # It is aborted by AbortReq
2009
2363
  num_ready_reqs += 1
2010
2364
  continue
2365
+
2011
2366
  req.grammar = req.grammar.result(timeout=0.03)
2012
2367
  self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
2013
2368
  if req.grammar is INVALID_GRAMMAR_OBJ:
2014
- req.set_finish_with_abort(
2015
- f"Invalid grammar request: {req.grammar_key=}"
2016
- )
2369
+ error_msg = f"Invalid grammar request: {req.grammar_key=}"
2370
+ req.set_finish_with_abort(error_msg)
2371
+
2017
2372
  num_ready_reqs += 1
2018
2373
  except futures._base.TimeoutError:
2019
2374
  req.grammar_wait_ct += 1
@@ -2045,9 +2400,8 @@ class Scheduler(
2045
2400
  req.grammar = req.grammar.result()
2046
2401
  self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
2047
2402
  if req.grammar is INVALID_GRAMMAR_OBJ:
2048
- req.set_finish_with_abort(
2049
- f"Invalid grammar request: {req.grammar_key=}"
2050
- )
2403
+ error_msg = f"Invalid grammar request: {req.grammar_key=}"
2404
+ req.set_finish_with_abort(error_msg)
2051
2405
  else:
2052
2406
  num_ready_reqs_max = num_ready_reqs
2053
2407
  num_timeout_reqs_max = num_timeout_reqs
@@ -2055,21 +2409,16 @@ class Scheduler(
2055
2409
  for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
2056
2410
  req = self.grammar_queue[i]
2057
2411
  req.grammar.cancel()
2412
+ self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
2058
2413
  error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
2059
2414
  req.set_finish_with_abort(error_msg)
2060
- self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
2415
+
2061
2416
  num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
2062
2417
 
2063
- self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
2418
+ for req in self.grammar_queue[:num_ready_reqs]:
2419
+ self._add_request_to_queue(req)
2064
2420
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]
2065
2421
 
2066
- def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
2067
- if batch.next_batch_sampling_info:
2068
- if batch.next_batch_sampling_info.grammars is not None:
2069
- batch.next_batch_sampling_info.update_regex_vocab_mask()
2070
- self.current_stream.synchronize()
2071
- batch.next_batch_sampling_info.sampling_info_done.set()
2072
-
2073
2422
  def watchdog_thread(self):
2074
2423
  """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
2075
2424
  self.watchdog_last_forward_ct = 0
@@ -2152,9 +2501,8 @@ class Scheduler(
2152
2501
  self.req_to_token_pool.clear()
2153
2502
  self.token_to_kv_pool_allocator.clear()
2154
2503
 
2155
- if not self.spec_algorithm.is_none():
2156
- self.draft_worker.model_runner.req_to_token_pool.clear()
2157
- self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
2504
+ if self.draft_worker:
2505
+ self.draft_worker.clear_cache_pool()
2158
2506
 
2159
2507
  self.num_generated_tokens = 0
2160
2508
  self.forward_ct_decode = 0
@@ -2174,39 +2522,50 @@ class Scheduler(
2174
2522
  if_success = False
2175
2523
  return if_success
2176
2524
 
2177
- def get_load(self):
2525
+ def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
2178
2526
  # TODO(lsyin): use dynamically maintained num_waiting_tokens
2527
+
2179
2528
  if self.is_hybrid:
2180
- load_full = (
2529
+ num_tokens_full = (
2181
2530
  self.full_tokens_per_layer
2182
2531
  - self.token_to_kv_pool_allocator.full_available_size()
2183
2532
  - self.tree_cache.full_evictable_size()
2184
2533
  )
2185
- load_swa = (
2534
+ num_tokens_swa = (
2186
2535
  self.swa_tokens_per_layer
2187
2536
  - self.token_to_kv_pool_allocator.swa_available_size()
2188
2537
  - self.tree_cache.swa_evictable_size()
2189
2538
  )
2190
- load = max(load_full, load_swa)
2539
+ num_tokens = max(num_tokens_full, num_tokens_swa)
2191
2540
  else:
2192
- load = (
2541
+ num_tokens = (
2193
2542
  self.max_total_num_tokens
2194
2543
  - self.token_to_kv_pool_allocator.available_size()
2195
2544
  - self.tree_cache.evictable_size()
2196
2545
  )
2197
- load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
2546
+
2547
+ # Tokens in waiting queue, bootstrap queue, prealloc queue
2548
+ num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
2549
+ num_waiting_reqs = len(self.waiting_queue)
2198
2550
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
2199
- load += sum(
2551
+ num_tokens += sum(
2200
2552
  len(req.origin_input_ids)
2201
2553
  for req in self.disagg_prefill_bootstrap_queue.queue
2202
2554
  )
2555
+ num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
2203
2556
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
2204
- load += sum(
2557
+ num_tokens += sum(
2205
2558
  len(req.req.origin_input_ids)
2206
2559
  for req in self.disagg_decode_prealloc_queue.queue
2207
2560
  )
2561
+ num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)
2208
2562
 
2209
- return load
2563
+ return GetLoadReqOutput(
2564
+ dp_rank=self.dp_rank,
2565
+ num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
2566
+ num_waiting_reqs=num_waiting_reqs,
2567
+ num_tokens=num_tokens,
2568
+ )
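Summarizing the `get_load` rewrite: instead of a single scalar, it now reports structured load made of in-use KV-cache tokens plus the prompt tokens of requests still waiting in the various queues, alongside request counts. A small, hedged sketch of that accounting (a plain dict stands in for `GetLoadReqOutput`):

```python
# Hedged sketch of the load accounting; numbers and names are illustrative.
def compute_load(max_total_tokens, kv_available, kv_evictable,
                 waiting_prompts, num_running_reqs):
    num_tokens = max_total_tokens - kv_available - kv_evictable
    num_tokens += sum(len(p) for p in waiting_prompts)   # queued prompt tokens
    return {
        "num_tokens": num_tokens,
        "num_waiting_reqs": len(waiting_prompts),
        "num_reqs": num_running_reqs + len(waiting_prompts),
    }

print(compute_load(10_000, kv_available=7_000, kv_evictable=1_000,
                   waiting_prompts=[[0] * 128, [0] * 256], num_running_reqs=3))
# {'num_tokens': 2384, 'num_waiting_reqs': 2, 'num_reqs': 5}
```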
2210
2569
 
2211
2570
  def get_internal_state(self, recv_req: GetInternalStateReq):
2212
2571
  ret = dict(global_server_args_dict)
@@ -2221,10 +2580,9 @@ class Scheduler(
2221
2580
  "token_capacity": int(self.max_total_num_tokens),
2222
2581
  }
2223
2582
 
2224
- if not _is_cpu:
2225
- ret["memory_usage"]["cuda_graph"] = round(
2226
- self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
2227
- )
2583
+ ret["memory_usage"]["graph"] = round(
2584
+ self.tp_worker.worker.model_runner.graph_mem_usage, 2
2585
+ )
2228
2586
 
2229
2587
  if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
2230
2588
  ret["avg_spec_accept_length"] = (
@@ -2233,15 +2591,13 @@ class Scheduler(
2233
2591
  if RECORD_STEP_TIME:
2234
2592
  ret["step_time_dict"] = self.step_time_dict
2235
2593
 
2236
- ret["load"] = self.get_load()
2237
-
2238
2594
  return GetInternalStateReqOutput(internal_state=ret)
2239
2595
 
2240
2596
  def set_internal_state(self, recv_req: SetInternalStateReq):
2241
2597
  server_args_dict = recv_req.server_args
2242
2598
  args_allow_update = set(
2243
2599
  [
2244
- "max_micro_batch_size",
2600
+ "pp_max_micro_batch_size",
2245
2601
  "speculative_accept_threshold_single",
2246
2602
  "speculative_accept_threshold_acc",
2247
2603
  ]
@@ -2252,7 +2608,7 @@ class Scheduler(
2252
2608
  logging.warning(f"Updating {k} is not supported.")
2253
2609
  if_success = False
2254
2610
  break
2255
- elif k == "max_micro_batch_size" and (
2611
+ elif k == "pp_max_micro_batch_size" and (
2256
2612
  v > self.max_running_requests // self.pp_size or v < 1
2257
2613
  ):
2258
2614
  logging.warning(
@@ -2310,7 +2666,7 @@ class Scheduler(
2310
2666
  if self.enable_hicache_storage:
2311
2667
  # to release prefetch events associated with the request
2312
2668
  self.tree_cache.release_aborted_request(req.rid)
2313
- self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
2669
+ self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
2314
2670
  # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
2315
2671
  if self.disaggregation_mode == DisaggregationMode.DECODE:
2316
2672
  self.tree_cache.cache_finished_req(req)
@@ -2331,31 +2687,31 @@ class Scheduler(
2331
2687
  # Delete requests not in the waiting queue when PD disaggregation is enabled
2332
2688
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
2333
2689
  # Abort requests that have not yet been bootstrapped
2334
- for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
2335
- logger.debug(f"Abort bootstrap queue request. {req.rid=}")
2690
+ for req in self.disagg_prefill_bootstrap_queue.queue:
2336
2691
  if recv_req.abort_all or req.rid.startswith(recv_req.rid):
2692
+ logger.debug(f"Abort bootstrap queue request. {req.rid=}")
2337
2693
  if hasattr(req.disagg_kv_sender, "abort"):
2338
2694
  req.disagg_kv_sender.abort()
2339
2695
 
2340
2696
  # Abort in-flight requests
2341
- for i, req in enumerate(self.disagg_prefill_inflight_queue):
2342
- logger.debug(f"Abort inflight queue request. {req.rid=}")
2697
+ for req in self.disagg_prefill_inflight_queue:
2343
2698
  if recv_req.abort_all or req.rid.startswith(recv_req.rid):
2699
+ logger.debug(f"Abort inflight queue request. {req.rid=}")
2344
2700
  if hasattr(req.disagg_kv_sender, "abort"):
2345
2701
  req.disagg_kv_sender.abort()
2346
2702
 
2347
2703
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
2348
2704
  # Abort requests that have not yet finished preallocation
2349
- for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
2350
- logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
2705
+ for decode_req in self.disagg_decode_prealloc_queue.queue:
2351
2706
  if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
2707
+ logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
2352
2708
  if hasattr(decode_req.kv_receiver, "abort"):
2353
2709
  decode_req.kv_receiver.abort()
2354
2710
 
2355
2711
  # Abort requests waiting for kvcache to release tree cache
2356
- for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
2357
- logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
2712
+ for decode_req in self.disagg_decode_transfer_queue.queue:
2358
2713
  if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
2714
+ logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
2359
2715
  if hasattr(decode_req.kv_receiver, "abort"):
2360
2716
  decode_req.kv_receiver.abort()
2361
2717
 
@@ -2398,6 +2754,22 @@ class Scheduler(
2398
2754
  self.send_to_detokenizer.send_pyobj(recv_req)
2399
2755
  return recv_req
2400
2756
 
2757
+ def init_weights_send_group_for_remote_instance(
2758
+ self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
2759
+ ):
2760
+ """Init the seed and client instance communication group."""
2761
+ success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
2762
+ recv_req
2763
+ )
2764
+ return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)
2765
+
2766
+ def send_weights_to_remote_instance(
2767
+ self, recv_req: SendWeightsToRemoteInstanceReqInput
2768
+ ):
2769
+ """Send the seed instance weights to the destination instance."""
2770
+ success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
2771
+ return SendWeightsToRemoteInstanceReqOutput(success, message)
2772
+
2401
2773
  def slow_down(self, recv_req: SlowDownReqInput):
2402
2774
  t = recv_req.forward_sleep_time
2403
2775
  if t is not None and t <= 0:
@@ -2406,11 +2778,12 @@ class Scheduler(
2406
2778
  return SlowDownReqOutput()
2407
2779
 
2408
2780
  def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
2409
- if recv_req == ExpertDistributionReq.START_RECORD:
2781
+ action = recv_req.action
2782
+ if action == ExpertDistributionReqType.START_RECORD:
2410
2783
  get_global_expert_distribution_recorder().start_record()
2411
- elif recv_req == ExpertDistributionReq.STOP_RECORD:
2784
+ elif action == ExpertDistributionReqType.STOP_RECORD:
2412
2785
  get_global_expert_distribution_recorder().stop_record()
2413
- elif recv_req == ExpertDistributionReq.DUMP_RECORD:
2786
+ elif action == ExpertDistributionReqType.DUMP_RECORD:
2414
2787
  get_global_expert_distribution_recorder().dump_record()
2415
2788
  else:
2416
2789
  raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
@@ -2493,7 +2866,8 @@ class IdleSleeper:
2493
2866
 
2494
2867
 
2495
2868
  def is_health_check_generate_req(recv_req):
2496
- return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
2869
+ rid = getattr(recv_req, "rid", None)
2870
+ return rid is not None and rid.startswith("HEALTH_CHECK")
2497
2871
 
2498
2872
 
2499
2873
  def is_work_request(recv_req):
@@ -2517,10 +2891,12 @@ def run_scheduler_process(
2517
2891
  pp_rank: int,
2518
2892
  dp_rank: Optional[int],
2519
2893
  pipe_writer,
2520
- balance_meta: Optional[DPBalanceMeta] = None,
2521
2894
  ):
2522
- # Generate the prefix
2895
+ # Generate the logger prefix
2523
2896
  prefix = ""
2897
+ if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
2898
+ # [For Router] if env var "SGLANG_DP_RANK" exists, set dp_rank to its value
2899
+ dp_rank = int(os.environ["SGLANG_DP_RANK"])
2524
2900
  if dp_rank is not None:
2525
2901
  prefix += f" DP{dp_rank}"
2526
2902
  if server_args.tp_size > 1:
@@ -2536,10 +2912,6 @@ def run_scheduler_process(
2536
2912
  kill_itself_when_parent_died()
2537
2913
  parent_process = psutil.Process().parent()
2538
2914
 
2539
- # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
2540
- if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
2541
- dp_rank = int(os.environ["SGLANG_DP_RANK"])
2542
-
2543
2915
  # Configure the logger
2544
2916
  configure_logger(server_args, prefix=prefix)
2545
2917
  suppress_other_loggers()
@@ -2547,6 +2919,15 @@ def run_scheduler_process(
2547
2919
  # Set cpu affinity to this gpu process
2548
2920
  if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
2549
2921
  set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
2922
+ if (numa_node := server_args.numa_node) is not None:
2923
+ numa_bind_to_node(numa_node[gpu_id])
2924
+
2925
+ # Set up tracing
2926
+ if server_args.enable_trace:
2927
+ process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
2928
+ if server_args.disaggregation_mode == "null":
2929
+ thread_label = "Scheduler"
2930
+ trace_set_thread_info(thread_label, tp_rank, dp_rank)
2550
2931
 
2551
2932
  # Create a scheduler and run the event loop
2552
2933
  try:
@@ -2558,7 +2939,6 @@ def run_scheduler_process(
2558
2939
  moe_ep_rank,
2559
2940
  pp_rank,
2560
2941
  dp_rank,
2561
- dp_balance_meta=balance_meta,
2562
2942
  )
2563
2943
  pipe_writer.send(
2564
2944
  {