sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -24,16 +24,16 @@ from collections import deque
  from concurrent import futures
  from dataclasses import dataclass
  from http import HTTPStatus
- from types import SimpleNamespace
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import Deque, Dict, List, Optional, Tuple, Union

  import psutil
  import setproctitle
  import torch
  import zmq
+ from torch.cuda import Stream as CudaStream
+ from torch.cuda import StreamContext as CudaStreamContext
  from torch.distributed import barrier

- from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.constrained.base_grammar_backend import (
  INVALID_GRAMMAR_OBJ,
@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
  DecodeTransferQueue,
  SchedulerDisaggregationDecodeMixin,
  )
+ from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
+ DecodeKVCacheOffloadManager,
+ )
  from sglang.srt.disaggregation.prefill import (
  PrefillBootstrapQueue,
  SchedulerDisaggregationPrefillMixin,
@@ -56,24 +59,23 @@ from sglang.srt.disaggregation.utils import (
  prepare_abort,
  )
  from sglang.srt.distributed import get_pp_group, get_world_group
+ from sglang.srt.environ import envs
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
- from sglang.srt.hf_transformers_utils import (
- get_processor,
- get_tokenizer,
- get_tokenizer_from_processor,
- )
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
- from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.moe import initialize_moe_config
  from sglang.srt.managers.io_struct import (
  AbortReq,
+ BaseBatchReq,
+ BaseReq,
  BatchTokenizedEmbeddingReqInput,
  BatchTokenizedGenerateReqInput,
  ClearHiCacheReqInput,
  ClearHiCacheReqOutput,
  CloseSessionReqInput,
+ DestroyWeightsUpdateGroupReqInput,
  ExpertDistributionReq,
  ExpertDistributionReqOutput,
+ ExpertDistributionReqType,
  FlushCacheReqInput,
  FlushCacheReqOutput,
  FreezeGCReq,
@@ -88,8 +90,6 @@ from sglang.srt.managers.io_struct import (
  InitWeightsUpdateGroupReqInput,
  LoadLoRAAdapterReqInput,
  LoadLoRAAdapterReqOutput,
- MultiTokenizerRegisterReq,
- MultiTokenizerWrapper,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
@@ -109,15 +109,18 @@ from sglang.srt.managers.io_struct import (
  UnloadLoRAAdapterReqOutput,
  UpdateWeightFromDiskReqInput,
  UpdateWeightsFromDistributedReqInput,
+ UpdateWeightsFromIPCReqInput,
  UpdateWeightsFromTensorReqInput,
  )
  from sglang.srt.managers.mm_utils import init_embedding_cache
+ from sglang.srt.managers.overlap_utils import FutureMap
  from sglang.srt.managers.schedule_batch import (
  FINISH_ABORT,
+ ModelWorkerBatch,
  MultimodalInputs,
  Req,
+ RequestStage,
  ScheduleBatch,
- global_server_args_dict,
  )
  from sglang.srt.managers.schedule_policy import (
  AddReqResult,
@@ -132,31 +135,30 @@ from sglang.srt.managers.scheduler_metrics_mixin import (
  from sglang.srt.managers.scheduler_output_processor_mixin import (
  SchedulerOutputProcessorMixin,
  )
+ from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
  from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
  from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
+ from sglang.srt.managers.scheduler_runtime_checker_mixin import (
+ SchedulerRuntimeCheckerMixin,
+ )
  from sglang.srt.managers.scheduler_update_weights_mixin import (
  SchedulerUpdateWeightsMixin,
  )
  from sglang.srt.managers.session_controller import Session
- from sglang.srt.managers.tp_worker import TpModelWorker
- from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
- from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
+ from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
- from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
+ from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
- from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
  from sglang.srt.parser.reasoning_parser import ReasoningParser
- from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
- from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.tracing.trace import (
  process_tracing_init,
- trace_event,
  trace_set_proc_propagate_context,
  trace_set_thread_info,
- trace_slice,
+ trace_slice_batch,
  trace_slice_end,
  trace_slice_start,
  )
@@ -170,8 +172,8 @@ from sglang.srt.utils import (
  freeze_gc,
  get_available_gpu_memory,
  get_bool_env_var,
+ get_int_env_var,
  get_zmq_socket,
- is_cpu,
  kill_itself_when_parent_died,
  numa_bind_to_node,
  point_to_point_pyobj,
@@ -182,32 +184,25 @@ from sglang.srt.utils import (
  set_random_seed,
  suppress_other_loggers,
  )
+ from sglang.srt.utils.hf_transformers_utils import (
+ get_processor,
+ get_tokenizer,
+ get_tokenizer_from_processor,
+ )
+ from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.utils import TypeBasedDispatcher, get_exception_traceback

  logger = logging.getLogger(__name__)

  # Test retract decode for debugging purposes
- TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
+ TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
+ TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
  GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

- _is_cpu = is_cpu()
-
-
- @dataclass
- class GenerationBatchResult:
- logits_output: Optional[LogitsProcessorOutput]
- pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
- next_token_ids: Optional[List[int]]
- extend_input_len_per_req: List[int]
- extend_logprob_start_len_per_req: List[int]
- bid: int
- can_run_cuda_graph: bool
-

  @dataclass
  class EmbeddingBatchResult:
  embeddings: torch.Tensor
- bid: int


  class Scheduler(
@@ -217,6 +212,8 @@ class Scheduler(
  SchedulerMetricsMixin,
  SchedulerDisaggregationDecodeMixin,
  SchedulerDisaggregationPrefillMixin,
+ SchedulerRuntimeCheckerMixin,
+ SchedulerPPMixin,
  ):
  """A scheduler that manages a tensor parallel GPU worker."""

@@ -229,7 +226,6 @@ class Scheduler(
  moe_ep_rank: int,
  pp_rank: int,
  dp_rank: Optional[int],
- dp_balance_meta: Optional[DPBalanceMeta] = None,
  ):
  # Parse args
  self.server_args = server_args
@@ -242,6 +238,16 @@ class Scheduler(
  self.pp_size = server_args.pp_size
  self.dp_size = server_args.dp_size
  self.schedule_policy = server_args.schedule_policy
+ self.enable_priority_scheduling = server_args.enable_priority_scheduling
+ self.abort_on_priority_when_disabled = (
+ server_args.abort_on_priority_when_disabled
+ )
+ self.schedule_low_priority_values_first = (
+ server_args.schedule_low_priority_values_first
+ )
+ self.priority_scheduling_preemption_threshold = (
+ server_args.priority_scheduling_preemption_threshold
+ )
  self.enable_lora = server_args.enable_lora
  self.max_loras_per_batch = server_args.max_loras_per_batch
  self.enable_overlap = not server_args.disable_overlap_schedule
@@ -250,7 +256,10 @@ class Scheduler(
  self.enable_metrics_for_all_schedulers = (
  server_args.enable_metrics_for_all_schedulers
  )
- self.enable_kv_cache_events = server_args.kv_events_config is not None
+ self.enable_kv_cache_events = bool(
+ server_args.kv_events_config and tp_rank == 0
+ )
+ self.enable_trace = server_args.enable_trace
  self.stream_interval = server_args.stream_interval
  self.spec_algorithm = SpeculativeAlgorithm.from_string(
  server_args.speculative_algorithm
@@ -273,47 +282,7 @@ class Scheduler(
  self.model_config = ModelConfig.from_server_args(server_args)

  # Init inter-process communication
- context = zmq.Context(2)
- self.idle_sleeper = None
- if self.pp_rank == 0 and self.attn_tp_rank == 0:
- self.recv_from_tokenizer = get_zmq_socket(
- context, zmq.PULL, port_args.scheduler_input_ipc_name, False
- )
- self.recv_from_rpc = get_zmq_socket(
- context, zmq.DEALER, port_args.rpc_ipc_name, False
- )
-
- self.send_to_tokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.tokenizer_ipc_name, False
- )
- if server_args.skip_tokenizer_init:
- # Directly send to the TokenizerManager
- self.send_to_detokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.tokenizer_ipc_name, False
- )
- else:
- # Send to the DetokenizerManager
- self.send_to_detokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.detokenizer_ipc_name, False
- )
-
- if self.server_args.sleep_on_idle:
- self.idle_sleeper = IdleSleeper(
- [
- self.recv_from_tokenizer,
- self.recv_from_rpc,
- ]
- )
- else:
- self.recv_from_tokenizer = None
- self.recv_from_rpc = None
- self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
- self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
-
- if self.current_scheduler_metrics_enabled():
- self.send_metrics_from_scheduler = get_zmq_socket(
- context, zmq.PUSH, port_args.metrics_ipc_name, False
- )
+ self.init_sockets(server_args, port_args)

  # Init tokenizer
  self.init_tokenizer()
@@ -336,12 +305,10 @@ class Scheduler(
  logger.info("Overlap scheduler is disabled for embedding models.")

  # Launch a tensor parallel worker
- if self.enable_overlap:
- TpWorkerClass = TpModelWorkerClient
- else:
- TpWorkerClass = TpModelWorker

- self.tp_worker = TpWorkerClass(
+ from sglang.srt.managers.tp_worker import TpModelWorker
+
+ self.tp_worker = TpModelWorker(
  server_args=server_args,
  gpu_id=gpu_id,
  tp_rank=tp_rank,
@@ -352,32 +319,16 @@ class Scheduler(
  )

  # Launch a draft worker for speculative decoding
- if self.spec_algorithm.is_eagle():
- from sglang.srt.speculative.eagle_worker import EAGLEWorker

- self.draft_worker = EAGLEWorker(
- gpu_id=gpu_id,
- tp_rank=tp_rank,
- moe_ep_rank=moe_ep_rank,
- server_args=server_args,
- nccl_port=port_args.nccl_port,
- target_worker=self.tp_worker,
- dp_rank=dp_rank,
- )
- elif self.spec_algorithm.is_standalone():
- from sglang.srt.speculative.standalone_worker import StandaloneWorker
+ self.launch_draft_worker(
+ gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+ )

- self.draft_worker = StandaloneWorker(
- gpu_id=gpu_id,
- tp_rank=tp_rank,
- moe_ep_rank=moe_ep_rank,
- server_args=server_args,
- nccl_port=port_args.nccl_port,
- target_worker=self.tp_worker,
- dp_rank=dp_rank,
- )
+ # Dispatch the model worker
+ if self.spec_algorithm.is_none():
+ self.model_worker = self.tp_worker
  else:
- self.draft_worker = None
+ self.model_worker = self.draft_worker

  # Get token and memory info from the model worker
  (
@@ -389,13 +340,12 @@ class Scheduler(
  self.max_req_input_len,
  self.random_seed,
  self.device,
- worker_global_server_args_dict,
  _,
  _,
  _,
  ) = self.tp_worker.get_worker_info()
- if global_server_args_dict["max_micro_batch_size"] is None:
- global_server_args_dict["max_micro_batch_size"] = max(
+ if get_global_server_args().pp_max_micro_batch_size is None:
+ get_global_server_args().pp_max_micro_batch_size = max(
  self.max_running_requests // server_args.pp_size, 1
  )

@@ -407,11 +357,12 @@ class Scheduler(
  self.world_group = get_world_group()

  self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
- global_server_args_dict.update(worker_global_server_args_dict)
  set_random_seed(self.random_seed)

  # Hybrid memory pool
  self.is_hybrid = self.tp_worker.is_hybrid
+ self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None
+
  if self.is_hybrid:
  self.sliding_window_size = self.tp_worker.sliding_window_size
  self.full_tokens_per_layer, self.swa_tokens_per_layer = (
@@ -455,9 +406,11 @@ class Scheduler(
  self.kv_transfer_speed_gb_s: float = 0.0
  self.kv_transfer_latency_ms: float = 0.0
  self.sessions: Dict[str, Session] = {}
- self.current_stream = torch.get_device_module(self.device).current_stream()
+ self.default_stream: CudaStream = torch.get_device_module(
+ self.device
+ ).current_stream()
  if self.device == "cpu":
- self.current_stream.synchronize = lambda: None  # No-op for CPU
+ self.default_stream.synchronize = lambda: None  # No-op for CPU
  self.forward_sleep_time = None

  # Init chunked prefill
@@ -486,23 +439,27 @@ class Scheduler(
  self.schedule_policy,
  self.tree_cache,
  self.enable_hierarchical_cache,
+ self.enable_priority_scheduling,
+ self.schedule_low_priority_values_first,
  )
+ # Enable preemption for priority scheduling.
+ self.try_preemption = self.enable_priority_scheduling
+
  assert (
  server_args.schedule_conservativeness >= 0
  ), "Invalid schedule_conservativeness"
  self.init_new_token_ratio = min(
- global_config.default_init_new_token_ratio
+ envs.SGLANG_INIT_NEW_TOKEN_RATIO.get()
  * server_args.schedule_conservativeness,
  1.0,
  )
  self.min_new_token_ratio = min(
- self.init_new_token_ratio
- * global_config.default_min_new_token_ratio_factor,
+ self.init_new_token_ratio * envs.SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR.get(),
  1.0,
  )
  self.new_token_ratio_decay = (
  self.init_new_token_ratio - self.min_new_token_ratio
- ) / global_config.default_new_token_ratio_decay_steps
+ ) / envs.SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS.get()
  self.new_token_ratio = self.init_new_token_ratio

  # Init watchdog thread
@@ -527,8 +484,9 @@ class Scheduler(

  # Init metrics stats
  self.init_metrics(tp_rank, pp_rank, dp_rank)
- self.init_kv_events(server_args.kv_events_config)
- self.init_dp_balance(dp_balance_meta)
+
+ if self.enable_kv_cache_events:
+ self.init_kv_events(server_args.kv_events_config)

  # Init disaggregation
  self.disaggregation_mode = DisaggregationMode(
@@ -539,6 +497,12 @@ class Scheduler(
  if get_bool_env_var("SGLANG_GC_LOG"):
  configure_gc_logger()

+ # Init prefill kv split size when deterministic inference is enabled with various attention backends
+ self.init_deterministic_inference_config()
+
+ # Init overlap
+ self.init_overlap()
+
  # Init request dispatcher
  self._request_dispatcher = TypeBasedDispatcher(
  [
@@ -553,6 +517,7 @@ class Scheduler(
  (CloseSessionReqInput, self.close_session),
  (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
  (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+ (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
  (
  InitWeightsSendGroupForRemoteInstanceReqInput,
  self.init_weights_send_group_for_remote_instance,
@@ -566,6 +531,7 @@ class Scheduler(
  self.update_weights_from_distributed,
  ),
  (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+ (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
  (GetWeightsByNameReqInput, self.get_weights_by_name),
  (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
  (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
@@ -578,11 +544,147 @@ class Scheduler(
  (ExpertDistributionReq, self.expert_distribution_handle),
  (LoadLoRAAdapterReqInput, self.load_lora_adapter),
  (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
- (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
  (GetLoadReqInput, self.get_load),
  ]
  )

+ def launch_draft_worker(
+ self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+ ):
+ if server_args.speculative_draft_load_format is not None:
+ server_args.load_format = server_args.speculative_draft_load_format
+ logger.info(
+ f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
+ )
+
+ if self.spec_algorithm.is_eagle():
+ from sglang.srt.speculative.eagle_worker import EAGLEWorker
+ from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
+
+ WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
+
+ self.draft_worker = WorkerClass(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
+ elif self.spec_algorithm.is_standalone():
+ from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+ self.draft_worker = StandaloneWorker(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
+ elif self.spec_algorithm.is_ngram():
+ from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+ self.draft_worker = NGRAMWorker(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
+ else:
+ self.draft_worker = None
+
+ def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
+ context = zmq.Context(2)
+ self.idle_sleeper = None
+
+ class SenderWrapper:
+ def __init__(self, socket: zmq.Socket):
+ self.socket = socket
+
+ def send_output(
+ self,
+ output: Union[BaseReq, BaseBatchReq],
+ recv_obj: Optional[Union[BaseReq, BaseBatchReq]] = None,
+ ):
+ if self.socket is None:
+ return
+
+ if (
+ isinstance(recv_obj, BaseReq)
+ and recv_obj.http_worker_ipc is not None
+ and output.http_worker_ipc is None
+ ):
+ # handle communicator reqs for multi-http worker case
+ output.http_worker_ipc = recv_obj.http_worker_ipc
+
+ self.socket.send_pyobj(output)
+
+ if self.pp_rank == 0 and self.attn_tp_rank == 0:
+ self.recv_from_tokenizer = get_zmq_socket(
+ context, zmq.PULL, port_args.scheduler_input_ipc_name, False
+ )
+ self.recv_from_rpc = get_zmq_socket(
+ context, zmq.DEALER, port_args.rpc_ipc_name, False
+ )
+
+ send_to_tokenizer = get_zmq_socket(
+ context, zmq.PUSH, port_args.tokenizer_ipc_name, False
+ )
+ if server_args.skip_tokenizer_init:
+ # Directly send to the TokenizerManager
+ send_to_detokenizer = get_zmq_socket(
+ context, zmq.PUSH, port_args.tokenizer_ipc_name, False
+ )
+ else:
+ # Send to the DetokenizerManager
+ send_to_detokenizer = get_zmq_socket(
+ context, zmq.PUSH, port_args.detokenizer_ipc_name, False
+ )
+
+ self.send_to_tokenizer = SenderWrapper(send_to_tokenizer)
+ self.send_to_detokenizer = SenderWrapper(send_to_detokenizer)
+
+ if self.server_args.sleep_on_idle:
+ self.idle_sleeper = IdleSleeper(
+ [
+ self.recv_from_tokenizer,
+ self.recv_from_rpc,
+ ]
+ )
+ else:
+ self.recv_from_tokenizer = None
+ self.recv_from_rpc = None
+ self.send_to_tokenizer = SenderWrapper(None)
+ self.send_to_detokenizer = SenderWrapper(None)
+
+ if self.current_scheduler_metrics_enabled():
+ self.send_metrics_from_scheduler = get_zmq_socket(
+ context, zmq.PUSH, port_args.metrics_ipc_name, False
+ )
+
+ def init_deterministic_inference_config(self):
+ """Initialize deterministic inference configuration for different attention backends."""
+ if not self.server_args.enable_deterministic_inference:
+ self.truncation_align_size = None
+ return
+
+ backend_sizes = {
+ "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
+ "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
+ }
+ env_var, default_size = backend_sizes.get(
+ self.server_args.attention_backend, (None, None)
+ )
+ self.truncation_align_size = (
+ get_int_env_var(env_var, default_size) if env_var else None
+ )
+
  def init_tokenizer(self):
  server_args = self.server_args
  self.is_generation = self.model_config.is_generation
@@ -654,6 +756,7 @@ class Scheduler(
  else self.tp_cpu_group
  ),
  page_size=self.page_size,
+ eviction_policy=server_args.radix_eviction_policy,
  hicache_ratio=server_args.hicache_ratio,
  hicache_size=server_args.hicache_size,
  hicache_write_policy=server_args.hicache_write_policy,
@@ -664,29 +767,22 @@ class Scheduler(
  hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
  model_name=server_args.served_model_name,
  storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
+ is_eagle=self.spec_algorithm.is_eagle(),
  )
  self.tp_worker.register_hicache_layer_transfer_counter(
  self.tree_cache.cache_controller.layer_done_counter
  )
  elif self.is_hybrid:
- assert (
- self.server_args.disaggregation_mode == "null"
- ), "Hybrid mode does not support disaggregation yet"
  self.tree_cache = SWARadixCache(
  req_to_token_pool=self.req_to_token_pool,
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
  sliding_window_size=self.sliding_window_size,
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
+ is_eagle=self.spec_algorithm.is_eagle(),
  )
- elif self.enable_lora:
- assert (
- not self.enable_hierarchical_cache
- ), "LoRA radix cache doesn't support hierarchical cache"
- assert (
- self.schedule_policy == "fcfs"
- ), "LoRA radix cache only supports FCFS policy"
- self.tree_cache = LoRARadixCache(
+ elif self.is_hybrid_gdn:
+ self.tree_cache = MambaRadixCache(
  req_to_token_pool=self.req_to_token_pool,
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
  page_size=self.page_size,
@@ -706,6 +802,7 @@ class Scheduler(
  tp_size=self.tp_size,
  rank=self.tp_rank,
  tp_group=self.tp_group,
+ eviction_policy=server_args.radix_eviction_policy,
  )
  else:
  self.tree_cache = RadixCache(
@@ -714,16 +811,36 @@ class Scheduler(
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
  enable_kv_cache_events=self.enable_kv_cache_events,
+ eviction_policy=server_args.radix_eviction_policy,
+ is_eagle=self.spec_algorithm.is_eagle(),
  )

+ if (
+ server_args.disaggregation_mode == "decode"
+ and server_args.disaggregation_decode_enable_offload_kvcache
+ ):
+ self.decode_offload_manager = DecodeKVCacheOffloadManager(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ tp_group=(
+ self.attn_tp_cpu_group
+ if self.server_args.enable_dp_attention
+ else self.tp_cpu_group
+ ),
+ tree_cache=self.tree_cache,
+ server_args=self.server_args,
+ )
+ else:
+ self.decode_offload_manager = None
+
  self.decode_mem_cache_buf_multiplier = (
  1
  if self.spec_algorithm.is_none()
  else (
  server_args.speculative_num_draft_tokens
  + (
- server_args.speculative_eagle_topk
- * server_args.speculative_num_steps
+ (server_args.speculative_eagle_topk or 1)
+ * (server_args.speculative_num_steps or 1)
  )
  )
  )
@@ -746,7 +863,7 @@ class Scheduler(
  self.disagg_metadata_buffers = MetadataBuffers(
  buffer_size,
  hidden_size=self.model_config.hf_text_config.hidden_size,
- dtype=self.model_config.dtype,
+ hidden_states_dtype=self.model_config.dtype,
  custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
  )

@@ -766,7 +883,7 @@ class Scheduler(
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
  draft_token_to_kv_pool=(
  None
- if self.draft_worker is None
+ if self.draft_worker is None or self.spec_algorithm.is_ngram()
  else self.draft_worker.model_runner.token_to_kv_pool
  ),
  req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -795,7 +912,7 @@ class Scheduler(
  self.disagg_metadata_buffers = MetadataBuffers(
  buffer_size,
  hidden_size=self.model_config.hf_text_config.hidden_size,
- dtype=self.model_config.dtype,
+ hidden_states_dtype=self.model_config.dtype,
  custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
  )

@@ -803,7 +920,7 @@ class Scheduler(
  token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
  draft_token_to_kv_pool=(
  None
- if self.draft_worker is None
+ if self.draft_worker is None or self.spec_algorithm.is_ngram()
  else self.draft_worker.model_runner.token_to_kv_pool
  ),
  req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -824,6 +941,34 @@ class Scheduler(
824
941
  # The prefill requests that are in the middle of kv sending
825
942
  self.disagg_prefill_inflight_queue: List[Req] = []
826
943
 
944
+ def init_overlap(self):
945
+ if not self.enable_overlap:
946
+ return
947
+
948
+ self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
949
+ self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
950
+ self.device
951
+ ).stream(self.forward_stream)
952
+ self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
953
+ self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
954
+ self.device
955
+ ).stream(self.copy_stream)
956
+
957
+ self.future_map = FutureMap(
958
+ self.max_running_requests, self.device, self.spec_algorithm
959
+ )
960
+ self.batch_record_buf = [None] * 2
961
+ self.batch_record_ct = 0
962
+
963
+ def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
964
+ # FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
965
+ # NOTE: More Reliable: record all tensors into the forward stream
966
+ # NOTE: - for all future tensors, we shall always read from future map
967
+ # - for all non-future tensors (produced only by schedule stream),
968
+ # we shall keep their references from being released for the whole forward pass
969
+ self.batch_record_ct = (self.batch_record_ct + 1) % 2
970
+ self.batch_record_buf[self.batch_record_ct] = model_worker_batch
971
+
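As the NOTE above explains, the two-slot buffer only needs to keep the previous ModelWorkerBatch alive until the forward pass that still reads its tensors has finished. A minimal standalone sketch of the same idea (class and variable names are illustrative, not the scheduler's API):

class BatchRecorder:
    """Two-slot holder that keeps strong references to the last two batches."""

    def __init__(self):
        self._buf = [None, None]
        self._ct = 0

    def record(self, batch):
        # Overwriting a slot drops the batch recorded two iterations ago;
        # until then its GPU tensors cannot be garbage-collected.
        self._ct = (self._ct + 1) % 2
        self._buf[self._ct] = batch


recorder = BatchRecorder()
for step in range(4):
    recorder.record({"step": step})  # stand-in for a ModelWorkerBatch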
827
972
  def init_moe_config(self):
828
973
  if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
829
974
  initialize_moe_config(self.server_args)
@@ -838,10 +983,6 @@ class Scheduler(
838
983
  batch = self.get_next_batch_to_run()
839
984
  self.cur_batch = batch
840
985
 
841
- if batch:
842
- for req in batch.reqs:
843
- trace_event("schedule", req.rid)
844
-
845
986
  if batch:
846
987
  result = self.run_batch(batch)
847
988
  self.process_batch_result(batch, result)
@@ -854,7 +995,7 @@ class Scheduler(
854
995
  @DynamicGradMode()
855
996
  def event_loop_overlap(self):
856
997
  """A scheduler loop that overlaps the CPU processing and GPU computation."""
857
- self.result_queue = deque()
998
+ self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()
858
999
 
859
1000
  while True:
860
1001
  recv_reqs = self.recv_requests()
@@ -863,173 +1004,24 @@ class Scheduler(
863
1004
  batch = self.get_next_batch_to_run()
864
1005
  self.cur_batch = batch
865
1006
 
1007
+ batch_result = None
866
1008
  if batch:
867
- for req in batch.reqs:
868
- trace_event("schedule", req.rid)
869
-
870
- if batch:
871
- batch.launch_done = threading.Event()
872
- result = self.run_batch(batch)
873
- self.result_queue.append((batch.copy(), result))
874
-
875
- if self.last_batch is None:
876
- # Create a dummy first batch to start the pipeline for overlap schedule.
877
- # It is now used for triggering the sampling_info_done event.
878
- tmp_batch = ScheduleBatch(
879
- reqs=None,
880
- forward_mode=ForwardMode.DUMMY_FIRST,
881
- next_batch_sampling_info=self.tp_worker.cur_sampling_info,
882
- )
883
- self.process_batch_result(tmp_batch, None, batch.launch_done)
1009
+ batch_result = self.run_batch(batch)
1010
+ self.result_queue.append((batch.copy(), batch_result))
884
1011
 
885
1012
  if self.last_batch:
886
1013
  # Process the results of the last batch
887
1014
  tmp_batch, tmp_result = self.result_queue.popleft()
888
- tmp_batch.next_batch_sampling_info = (
889
- self.tp_worker.cur_sampling_info if batch else None
890
- )
891
- # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
892
- self.process_batch_result(
893
- tmp_batch, tmp_result, batch.launch_done if batch else None
894
- )
1015
+ self.process_batch_result(tmp_batch, tmp_result)
895
1016
  elif batch is None:
896
1017
  # When the server is idle, do self-check and re-init some states
897
1018
  self.self_check_during_idle()
898
1019
 
1020
+ self.launch_batch_sample_if_needed(batch_result)
899
1021
  self.last_batch = batch
900
1022
 
901
- @DynamicGradMode()
902
- def event_loop_pp(self):
903
- """A non-overlap scheduler loop for pipeline parallelism."""
904
- mbs = [None] * self.pp_size
905
- last_mbs = [None] * self.pp_size
906
- self.running_mbs = [
907
- ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
908
- ]
909
- bids = [None] * self.pp_size
910
- pp_outputs: Optional[PPProxyTensors] = None
911
- while True:
912
- server_is_idle = True
913
- for mb_id in range(self.pp_size):
914
- self.running_batch = self.running_mbs[mb_id]
915
- self.last_batch = last_mbs[mb_id]
916
-
917
- recv_reqs = self.recv_requests()
918
- self.process_input_requests(recv_reqs)
919
- mbs[mb_id] = self.get_next_batch_to_run()
920
- self.running_mbs[mb_id] = self.running_batch
921
-
922
- self.cur_batch = mbs[mb_id]
923
- if self.cur_batch:
924
- server_is_idle = False
925
- result = self.run_batch(self.cur_batch)
926
-
927
- # (last rank) send the outputs to the next step
928
- if self.pp_group.is_last_rank:
929
- if self.cur_batch:
930
- next_token_ids, bids[mb_id] = (
931
- result.next_token_ids,
932
- result.bid,
933
- )
934
- if self.cur_batch.return_logprob:
935
- pp_outputs = PPProxyTensors(
936
- {
937
- "next_token_ids": next_token_ids,
938
- "extend_input_len_per_req": result.extend_input_len_per_req,
939
- "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
940
- }
941
- | (
942
- {
943
- f"logits_output.{k}": v
944
- for k, v in result.logits_output.__dict__.items()
945
- }
946
- if result.logits_output is not None
947
- else {}
948
- )
949
- )
950
- else:
951
- pp_outputs = PPProxyTensors(
952
- {
953
- "next_token_ids": next_token_ids,
954
- }
955
- )
956
- # send the output from the last round to let the next stage worker run post processing
957
- self.pp_group.send_tensor_dict(
958
- pp_outputs.tensors,
959
- all_gather_group=self.attn_tp_group,
960
- )
961
-
962
- # receive outputs and post-process (filter finished reqs) the coming microbatch
963
- next_mb_id = (mb_id + 1) % self.pp_size
964
- next_pp_outputs = None
965
- if mbs[next_mb_id] is not None:
966
- next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
967
- self.pp_group.recv_tensor_dict(
968
- all_gather_group=self.attn_tp_group
969
- )
970
- )
971
- mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
972
- logits_output_args = {
973
- k[len("logits_output.") :]: v
974
- for k, v in next_pp_outputs.tensors.items()
975
- if k.startswith("logits_output.")
976
- }
977
- if len(logits_output_args) > 0:
978
- logits_output = LogitsProcessorOutput(**logits_output_args)
979
- else:
980
- logits_output = None
981
- output_result = GenerationBatchResult(
982
- logits_output=logits_output,
983
- pp_hidden_states_proxy_tensors=None,
984
- next_token_ids=next_pp_outputs["next_token_ids"],
985
- extend_input_len_per_req=next_pp_outputs.tensors.get(
986
- "extend_input_len_per_req", None
987
- ),
988
- extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
989
- "extend_logprob_start_len_per_req", None
990
- ),
991
- bid=bids[next_mb_id],
992
- can_run_cuda_graph=result.can_run_cuda_graph,
993
- )
994
- self.process_batch_result(mbs[next_mb_id], output_result)
995
- last_mbs[next_mb_id] = mbs[next_mb_id]
996
-
997
- # (not last rank)
998
- if not self.pp_group.is_last_rank:
999
- if self.cur_batch:
1000
- bids[mb_id] = result.bid
1001
- # carry the outputs to the next stage
1002
- # send the outputs from the last round to let the next stage worker run post processing
1003
- if pp_outputs:
1004
- self.pp_group.send_tensor_dict(
1005
- pp_outputs.tensors,
1006
- all_gather_group=self.attn_tp_group,
1007
- )
1008
-
1009
- # send out reqs to the next stage
1010
- dp_offset = self.attn_dp_rank * self.attn_tp_size
1011
- if self.attn_tp_rank == 0:
1012
- point_to_point_pyobj(
1013
- recv_reqs,
1014
- self.pp_rank * self.tp_size + dp_offset,
1015
- self.world_group.device_group,
1016
- self.pp_rank * self.tp_size + dp_offset,
1017
- (self.pp_rank + 1) * self.tp_size + dp_offset,
1018
- )
1019
-
1020
- # send out proxy tensors to the next stage
1021
- if self.cur_batch:
1022
- self.pp_group.send_tensor_dict(
1023
- result.pp_hidden_states_proxy_tensors,
1024
- all_gather_group=self.attn_tp_group,
1025
- )
1026
-
1027
- pp_outputs = next_pp_outputs
1028
-
1029
- # When the server is idle, self-check and re-init some states
1030
- if server_is_idle:
1031
- # When the server is idle, do self-check and re-init some states
1032
- self.self_check_during_idle()
1023
+ if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
1024
+ self._check_runtime_mem_leak()
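The rewritten loop above keeps results exactly one batch behind: the batch launched in iteration N is post-processed in iteration N+1, which is what lets CPU-side scheduling overlap with GPU execution. A toy sketch of that one-step-delayed pipeline (names and the fake result strings are illustrative):

from collections import deque


def overlap_loop(batches):
    # Mirrors the scheduler's result_queue: launch now, process next iteration.
    result_queue = deque()
    last_batch = None
    processed = []
    for batch in batches + [None]:  # trailing None drains the queue
        if batch is not None:
            result_queue.append((batch, f"result-of-{batch}"))
        if last_batch is not None:
            processed.append(result_queue.popleft())
        last_batch = batch
    return processed


print(overlap_loop(["b0", "b1", "b2"]))  # [('b0', ...), ('b1', ...), ('b2', ...)]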
1033
1025
 
1034
1026
  def recv_requests(self) -> List[Req]:
1035
1027
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1131,10 +1123,13 @@ class Scheduler(
1131
1123
  src=self.tp_group.ranks[0],
1132
1124
  )
1133
1125
 
1134
- for req in recv_reqs:
1135
- if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
1136
- trace_set_proc_propagate_context(req.rid, req.trace_context)
1137
- trace_slice_start("", req.rid, anonymous=True)
1126
+ if self.enable_trace:
1127
+ for req in recv_reqs:
1128
+ if isinstance(
1129
+ req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
1130
+ ):
1131
+ trace_set_proc_propagate_context(req.rid, req.trace_context)
1132
+ trace_slice_start("", req.rid, anonymous=True)
1138
1133
 
1139
1134
  return recv_reqs
1140
1135
 
@@ -1149,37 +1144,13 @@ class Scheduler(
1149
1144
  self.return_health_check_ct += 1
1150
1145
  continue
1151
1146
 
1152
- # If it is a work request, accept or reject the request based on the request queue size.
1153
- if is_work_request(recv_req):
1154
- if len(self.waiting_queue) + 1 > self.max_queued_requests:
1155
- abort_req = AbortReq(
1156
- recv_req.rid,
1157
- finished_reason={
1158
- "type": "abort",
1159
- "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1160
- "message": "The request queue is full.",
1161
- },
1162
- )
1163
- self.send_to_tokenizer.send_pyobj(abort_req)
1164
- continue
1165
-
1166
- # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
1167
- if isinstance(recv_req, MultiTokenizerWrapper):
1168
- worker_id = recv_req.worker_id
1169
- recv_req = recv_req.obj
1170
- output = self._request_dispatcher(recv_req)
1171
- if output is not None:
1172
- output = MultiTokenizerWrapper(worker_id, output)
1173
- self.send_to_tokenizer.send_pyobj(output)
1174
- continue
1175
-
1176
1147
  output = self._request_dispatcher(recv_req)
1177
1148
  if output is not None:
1178
1149
  if isinstance(output, RpcReqOutput):
1179
1150
  if self.recv_from_rpc is not None:
1180
1151
  self.recv_from_rpc.send_pyobj(output)
1181
1152
  else:
1182
- self.send_to_tokenizer.send_pyobj(output)
1153
+ self.send_to_tokenizer.send_output(output, recv_req)
1183
1154
 
1184
1155
  def init_req_max_new_tokens(self, req):
1185
1156
  req.sampling_params.max_new_tokens = min(
@@ -1195,8 +1166,6 @@ class Scheduler(
1195
1166
  self,
1196
1167
  recv_req: TokenizedGenerateReqInput,
1197
1168
  ):
1198
- self.maybe_update_dp_balance_data(recv_req)
1199
-
1200
1169
  # Create a new request
1201
1170
  if (
1202
1171
  recv_req.session_params is None
@@ -1230,8 +1199,14 @@ class Scheduler(
1230
1199
  bootstrap_host=recv_req.bootstrap_host,
1231
1200
  bootstrap_port=recv_req.bootstrap_port,
1232
1201
  bootstrap_room=recv_req.bootstrap_room,
1202
+ disagg_mode=self.disaggregation_mode,
1233
1203
  data_parallel_rank=recv_req.data_parallel_rank,
1234
1204
  vocab_size=self.model_config.vocab_size,
1205
+ priority=recv_req.priority,
1206
+ metrics_collector=(
1207
+ self.metrics_collector if self.enable_metrics else None
1208
+ ),
1209
+ http_worker_ipc=recv_req.http_worker_ipc,
1235
1210
  )
1236
1211
  req.tokenizer = self.tokenizer
1237
1212
 
@@ -1330,29 +1305,31 @@ class Scheduler(
1330
1305
  or req.sampling_params.ebnf is not None
1331
1306
  or req.sampling_params.structural_tag is not None
1332
1307
  ):
1333
- assert self.grammar_backend is not None
1334
- if req.sampling_params.json_schema is not None:
1335
- key = ("json", req.sampling_params.json_schema)
1336
- elif req.sampling_params.regex is not None:
1337
- key = ("regex", req.sampling_params.regex)
1338
- elif req.sampling_params.ebnf is not None:
1339
- key = ("ebnf", req.sampling_params.ebnf)
1340
- elif req.sampling_params.structural_tag:
1341
- key = ("structural_tag", req.sampling_params.structural_tag)
1342
-
1343
- value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
1344
- req.grammar = value
1345
-
1346
- if not cache_hit:
1347
- req.grammar_key = key
1348
- add_to_grammar_queue = True
1308
+ if self.grammar_backend is None:
1309
+ error_msg = "Grammar-based generation (json_schema, regex, ebnf, structural_tag) is not supported when the server is launched with --grammar-backend none"
1310
+ req.set_finish_with_abort(error_msg)
1349
1311
  else:
1350
- if value is INVALID_GRAMMAR_OBJ: # We hit a cached invalid grammar.
1351
- error_msg = f"Invalid grammar request with cache hit: {key=}"
1352
- req.set_finish_with_abort(error_msg)
1312
+ if req.sampling_params.json_schema is not None:
1313
+ key = ("json", req.sampling_params.json_schema)
1314
+ elif req.sampling_params.regex is not None:
1315
+ key = ("regex", req.sampling_params.regex)
1316
+ elif req.sampling_params.ebnf is not None:
1317
+ key = ("ebnf", req.sampling_params.ebnf)
1318
+ elif req.sampling_params.structural_tag:
1319
+ key = ("structural_tag", req.sampling_params.structural_tag)
1320
+
1321
+ value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
1322
+ req.grammar = value
1323
+
1324
+ if not cache_hit:
1325
+ req.grammar_key = key
1326
+ add_to_grammar_queue = True
1327
+ else:
1328
+ if value is INVALID_GRAMMAR_OBJ: # We hit a cached invalid grammar.
1329
+ error_msg = f"Invalid grammar request with cache hit: {key=}"
1330
+ req.set_finish_with_abort(error_msg)
1353
1331
 
1354
1332
  if add_to_grammar_queue:
1355
- req.queue_time_start = time.perf_counter()
1356
1333
  self.grammar_queue.append(req)
1357
1334
  else:
1358
1335
  self._add_request_to_queue(req)
@@ -1368,20 +1345,6 @@ class Scheduler(
1368
1345
  for tokenized_req in recv_req:
1369
1346
  self.handle_generate_request(tokenized_req)
1370
1347
 
1371
- def _add_request_to_queue(self, req: Req):
1372
- req.queue_time_start = time.perf_counter()
1373
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
1374
- self._prefetch_kvcache(req)
1375
- self.disagg_prefill_bootstrap_queue.add(
1376
- req, self.model_config.num_key_value_heads
1377
- )
1378
- elif self.disaggregation_mode == DisaggregationMode.DECODE:
1379
- self.disagg_decode_prealloc_queue.add(req)
1380
- else:
1381
- self._prefetch_kvcache(req)
1382
- self.waiting_queue.append(req)
1383
- trace_slice_end("process req", req.rid, auto_next_anon=True)
1384
-
1385
1348
  def _prefetch_kvcache(self, req: Req):
1386
1349
  if self.enable_hicache_storage:
1387
1350
  req.init_next_round_input(self.tree_cache)
@@ -1391,20 +1354,106 @@ class Scheduler(
1391
1354
  last_hash = req.last_host_node.get_last_hash_value()
1392
1355
  matched_len = len(req.prefix_indices) + req.host_hit_length
1393
1356
  new_input_tokens = req.fill_ids[matched_len:]
1357
+
1358
+ prefix_keys = (
1359
+ req.last_node.get_prefix_hash_values(req.last_node.parent)
1360
+ if self.tree_cache.hicache_storage_pass_prefix_keys
1361
+ else None
1362
+ )
1394
1363
  self.tree_cache.prefetch_from_storage(
1395
- req.rid, req.last_host_node, new_input_tokens, last_hash
1364
+ req.rid,
1365
+ req.last_host_node,
1366
+ new_input_tokens,
1367
+ last_hash,
1368
+ prefix_keys,
1396
1369
  )
1397
1370
 
1398
- def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
1399
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
1400
- self.disagg_prefill_bootstrap_queue.extend(
1401
- reqs, self.model_config.num_key_value_heads
1371
+ def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
1372
+ if self.disaggregation_mode == DisaggregationMode.NULL:
1373
+ self._set_or_validate_priority(req)
1374
+ if self._abort_on_queued_limit(req):
1375
+ return
1376
+ self._prefetch_kvcache(req)
1377
+ self.waiting_queue.append(req)
1378
+ req.time_stats.wait_queue_entry_time = time.perf_counter()
1379
+ trace_slice_end("process req", req.rid, auto_next_anon=True)
1380
+ elif self.disaggregation_mode == DisaggregationMode.PREFILL:
1381
+ self._prefetch_kvcache(req)
1382
+ self.disagg_prefill_bootstrap_queue.add(
1383
+ req, self.model_config.num_key_value_heads
1402
1384
  )
1385
+ req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
1403
1386
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
1404
- # If this is a decode server, we put the request to the decode pending prealloc queue
1405
- self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
1387
+ self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
1388
+ if not is_retracted:
1389
+ req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
1406
1390
  else:
1407
- self.waiting_queue.extend(reqs)
1391
+ raise ValueError(f"Invalid {self.disaggregation_mode=}")
1392
+
1393
+ def _set_or_validate_priority(self, req: Req):
1394
+ """Set the default priority value, or abort the request based on the priority scheduling mode."""
1395
+ if self.enable_priority_scheduling and req.priority is None:
1396
+ if self.schedule_low_priority_values_first:
1397
+ req.priority = sys.maxsize
1398
+ else:
1399
+ req.priority = -sys.maxsize - 1
1400
+ elif (
1401
+ not self.enable_priority_scheduling
1402
+ and req.priority is not None
1403
+ and self.abort_on_priority_when_disabled
1404
+ ):
1405
+ abort_req = AbortReq(
1406
+ finished_reason={
1407
+ "type": "abort",
1408
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1409
+ "message": "Using priority is disabled for this server. Please send a new request without a priority.",
1410
+ },
1411
+ rid=req.rid,
1412
+ )
1413
+ self.send_to_tokenizer.send_output(abort_req, req)
1414
+
1415
+ def _abort_on_queued_limit(self, recv_req: Req) -> bool:
1416
+ """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
1417
+ if (
1418
+ self.max_queued_requests is None
1419
+ or len(self.waiting_queue) + 1 <= self.max_queued_requests
1420
+ ):
1421
+ return False
1422
+
1423
+ # Reject the incoming request by default.
1424
+ req_to_abort = recv_req
1425
+ message = "The request queue is full."
1426
+ if self.enable_priority_scheduling:
1427
+ # With priority scheduling, consider aborting an existing request based on priority.
1428
+ # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
1429
+ # max(...) + (direction * priority, queue_time_start) picks the least-preferred request.
1430
+ # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
1431
+ direction = 1 if self.schedule_low_priority_values_first else -1
1432
+ key_fn = lambda item: (
1433
+ direction * item[1].priority,
1434
+ item[1].time_stats.wait_queue_entry_time,
1435
+ )
1436
+ idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
1437
+ abort_existing_req = (
1438
+ direction * recv_req.priority < direction * candidate_req.priority
1439
+ )
1440
+ if abort_existing_req:
1441
+ self.waiting_queue.pop(idx)
1442
+ req_to_abort = candidate_req
1443
+ message = "The request is aborted by a higher priority request."
1444
+
1445
+ self.send_to_tokenizer.send_output(
1446
+ AbortReq(
1447
+ finished_reason={
1448
+ "type": "abort",
1449
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1450
+ "message": message,
1451
+ },
1452
+ rid=req_to_abort.rid,
1453
+ ),
1454
+ req_to_abort,
1455
+ )
1456
+ return req_to_abort.rid == recv_req.rid
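The eviction rule in _abort_on_queued_limit can be read as: pick the waiting request that is worst under (direction * priority, entry time) and abort it only if the incoming request is strictly better. A standalone sketch of that selection (the request tuples below are made up, not the scheduler's data structures):

# Each waiting request is (rid, priority, entry_time); a later entry_time means newer.
waiting = [("a", 5, 1.0), ("b", 9, 2.0), ("c", 9, 3.0)]
schedule_low_priority_values_first = True
direction = 1 if schedule_low_priority_values_first else -1

# Least-preferred request: worst priority first, newest arrival breaks ties.
idx, candidate = max(
    enumerate(waiting),
    key=lambda item: (direction * item[1][1], item[1][2]),
)
print(idx, candidate)  # 2 ('c', 9, 3.0): priority 9 ties with 'b', 'c' is newer

incoming_priority = 7
# Evict the existing candidate only if the incoming request is strictly better.
evict_existing = direction * incoming_priority < direction * candidate[1]
print(evict_existing)  # True: 7 beats 9 when low values are scheduled first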
1408
1457
 
1409
1458
  def handle_embedding_request(
1410
1459
  self,
@@ -1416,6 +1465,8 @@ class Scheduler(
1416
1465
  recv_req.input_ids,
1417
1466
  recv_req.sampling_params,
1418
1467
  token_type_ids=recv_req.token_type_ids,
1468
+ priority=recv_req.priority,
1469
+ http_worker_ipc=recv_req.http_worker_ipc,
1419
1470
  )
1420
1471
  req.tokenizer = self.tokenizer
1421
1472
 
@@ -1465,109 +1516,6 @@ class Scheduler(
1465
1516
  for tokenized_req in recv_req:
1466
1517
  self.handle_embedding_request(tokenized_req)
1467
1518
 
1468
- def self_check_during_idle(self):
1469
- self.check_memory()
1470
- self.check_tree_cache()
1471
- self.new_token_ratio = self.init_new_token_ratio
1472
- self.maybe_sleep_on_idle()
1473
-
1474
- def check_memory(self):
1475
- if self.is_hybrid:
1476
- (
1477
- full_num_used,
1478
- swa_num_used,
1479
- _,
1480
- _,
1481
- full_available_size,
1482
- full_evictable_size,
1483
- swa_available_size,
1484
- swa_evictable_size,
1485
- ) = self._get_swa_token_info()
1486
- memory_leak = full_num_used != 0 or swa_num_used != 0
1487
- token_msg = (
1488
- f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
1489
- f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
1490
- )
1491
- else:
1492
- _, _, available_size, evictable_size = self._get_token_info()
1493
- protected_size = self.tree_cache.protected_size()
1494
- memory_leak = (available_size + evictable_size) != (
1495
- # self.max_total_num_tokens
1496
- # if not self.enable_hierarchical_cache
1497
- # else self.max_total_num_tokens - protected_size
1498
- self.max_total_num_tokens
1499
- - protected_size
1500
- )
1501
- token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
1502
-
1503
- if memory_leak:
1504
- msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
1505
- raise ValueError(msg)
1506
-
1507
- if self.disaggregation_mode == DisaggregationMode.DECODE:
1508
- req_total_size = (
1509
- self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
1510
- )
1511
- else:
1512
- req_total_size = self.req_to_token_pool.size
1513
-
1514
- if len(self.req_to_token_pool.free_slots) != req_total_size:
1515
- msg = (
1516
- "req_to_token_pool memory leak detected!"
1517
- f"available_size={len(self.req_to_token_pool.free_slots)}, "
1518
- f"total_size={self.req_to_token_pool.size}\n"
1519
- )
1520
- raise ValueError(msg)
1521
-
1522
- if (
1523
- self.enable_metrics
1524
- and self.current_scheduler_metrics_enabled()
1525
- and time.perf_counter() > self.metrics_collector.last_log_time + 30
1526
- ):
1527
- # During idle time, also collect metrics every 30 seconds.
1528
- if self.is_hybrid:
1529
- (
1530
- full_num_used,
1531
- swa_num_used,
1532
- full_token_usage,
1533
- swa_token_usage,
1534
- _,
1535
- _,
1536
- _,
1537
- _,
1538
- ) = self._get_swa_token_info()
1539
- num_used = max(full_num_used, swa_num_used)
1540
- token_usage = max(full_token_usage, swa_token_usage)
1541
- else:
1542
- num_used, token_usage, _, _ = self._get_token_info()
1543
- num_running_reqs = len(self.running_batch.reqs)
1544
- self.stats.num_running_reqs = num_running_reqs
1545
- self.stats.num_used_tokens = num_used
1546
- self.stats.token_usage = round(token_usage, 2)
1547
- self.stats.gen_throughput = 0
1548
- self.stats.num_queue_reqs = len(self.waiting_queue)
1549
- self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
1550
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
1551
- self.stats.num_prefill_prealloc_queue_reqs = len(
1552
- self.disagg_prefill_bootstrap_queue.queue
1553
- )
1554
- self.stats.num_prefill_inflight_queue_reqs = len(
1555
- self.disagg_prefill_inflight_queue
1556
- )
1557
- if self.disaggregation_mode == DisaggregationMode.DECODE:
1558
- self.stats.num_decode_prealloc_queue_reqs = len(
1559
- self.disagg_decode_prealloc_queue.queue
1560
- )
1561
- self.stats.num_decode_transfer_queue_reqs = len(
1562
- self.disagg_decode_transfer_queue.queue
1563
- )
1564
- self.metrics_collector.log_stats(self.stats)
1565
- self._publish_kv_events()
1566
-
1567
- def check_tree_cache(self):
1568
- if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
1569
- self.tree_cache.sanity_check()
1570
-
1571
1519
  def _get_token_info(self):
1572
1520
  available_size = self.token_to_kv_pool_allocator.available_size()
1573
1521
  evictable_size = self.tree_cache.evictable_size()
@@ -1575,6 +1523,35 @@ class Scheduler(
1575
1523
  token_usage = num_used / self.max_total_num_tokens
1576
1524
  return num_used, token_usage, available_size, evictable_size
1577
1525
 
1526
+ def _get_mamba_token_info(self):
1527
+ is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
1528
+ full_available_size = self.token_to_kv_pool_allocator.available_size()
1529
+ full_evictable_size = (
1530
+ self.tree_cache.full_evictable_size() if is_radix_tree else 0
1531
+ )
1532
+ mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
1533
+ mamba_evictable_size = (
1534
+ self.tree_cache.mamba_evictable_size() if is_radix_tree else 0
1535
+ )
1536
+ full_num_used = self.token_to_kv_pool_allocator.size - (
1537
+ full_available_size + full_evictable_size
1538
+ )
1539
+ mamba_num_used = self.req_to_token_pool.mamba_pool.size - (
1540
+ mamba_available_size + mamba_evictable_size
1541
+ )
1542
+ full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
1543
+ mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
1544
+ return (
1545
+ full_num_used,
1546
+ mamba_num_used,
1547
+ full_token_usage,
1548
+ mamba_usage,
1549
+ full_available_size,
1550
+ full_evictable_size,
1551
+ mamba_available_size,
1552
+ mamba_evictable_size,
1553
+ )
1554
+
1578
1555
  def _get_swa_token_info(self):
1579
1556
  full_available_size = self.token_to_kv_pool_allocator.full_available_size()
1580
1557
  full_evictable_size = self.tree_cache.full_evictable_size()
@@ -1608,7 +1585,7 @@ class Scheduler(
1608
1585
  chunked_req_to_exclude.add(self.chunked_req)
1609
1586
  self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
1610
1587
  # chunked request keeps its rid but will get a new req_pool_idx
1611
- if self.tp_worker.worker.model_runner.is_hybrid_gdn:
1588
+ if self.tp_worker.model_runner.mambaish_config is not None:
1612
1589
  self.req_to_token_pool.free(
1613
1590
  self.chunked_req.req_pool_idx, free_mamba_cache=False
1614
1591
  )
@@ -1660,13 +1637,12 @@ class Scheduler(
1660
1637
 
1661
1638
  # Handle DP attention
1662
1639
  if need_dp_attn_preparation:
1663
- self.maybe_handle_dp_balance_data()
1664
1640
  ret = self.prepare_mlp_sync_batch(ret)
1665
1641
 
1666
1642
  return ret
1667
1643
 
1668
1644
  def get_num_allocatable_reqs(self, running_bs):
1669
- res = global_server_args_dict["max_micro_batch_size"] - running_bs
1645
+ res = get_global_server_args().pp_max_micro_batch_size - running_bs
1670
1646
  if self.pp_size > 1:
1671
1647
  res = min(res, self.req_to_token_pool.available_size())
1672
1648
  return res
@@ -1676,6 +1652,10 @@ class Scheduler(
1676
1652
  if self.grammar_queue:
1677
1653
  self.move_ready_grammar_requests()
1678
1654
 
1655
+ if self.try_preemption:
1656
+ # Reset batch_is_full to try preemption with a prefill adder.
1657
+ self.running_batch.batch_is_full = False
1658
+
1679
1659
  # Handle the cases where prefill is not allowed
1680
1660
  if (
1681
1661
  self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1688,7 +1668,11 @@ class Scheduler(
1688
1668
  # as the space for the chunked request has just been released.
1689
1669
  # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
1690
1670
  # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
1691
- if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
1671
+ if (
1672
+ self.get_num_allocatable_reqs(running_bs) <= 0
1673
+ and not self.chunked_req
1674
+ and not self.try_preemption
1675
+ ):
1692
1676
  self.running_batch.batch_is_full = True
1693
1677
  return None
1694
1678
 
@@ -1708,6 +1692,7 @@ class Scheduler(
1708
1692
  self.max_prefill_tokens,
1709
1693
  self.chunked_prefill_size,
1710
1694
  running_bs if self.is_mixed_chunk else 0,
1695
+ self.priority_scheduling_preemption_threshold,
1711
1696
  )
1712
1697
 
1713
1698
  if self.chunked_req is not None:
@@ -1728,15 +1713,19 @@ class Scheduler(
1728
1713
  self.running_batch.batch_is_full = True
1729
1714
  break
1730
1715
 
1716
+ running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
1731
1717
  if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
1732
1718
  self.running_batch.batch_is_full = True
1733
- break
1734
-
1735
1719
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
1736
1720
  # In prefill mode, prealloc queue and transfer queue can also take memory,
1737
1721
  # so we need to check against the actual available size.
1738
1722
  if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
1739
1723
  self.running_batch.batch_is_full = True
1724
+
1725
+ if self.running_batch.batch_is_full:
1726
+ if not self.try_preemption:
1727
+ break
1728
+ if not adder.preempt_to_schedule(req, self.server_args):
1740
1729
  break
1741
1730
 
1742
1731
  if self.enable_hicache_storage:
@@ -1746,7 +1735,11 @@ class Scheduler(
1746
1735
  continue
1747
1736
 
1748
1737
  req.init_next_round_input(self.tree_cache)
1749
- res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
1738
+ res = adder.add_one_req(
1739
+ req,
1740
+ has_chunked_req=(self.chunked_req is not None),
1741
+ truncation_align_size=self.truncation_align_size,
1742
+ )
1750
1743
 
1751
1744
  if res != AddReqResult.CONTINUE:
1752
1745
  if res == AddReqResult.NO_TOKEN:
@@ -1767,11 +1760,14 @@ class Scheduler(
1767
1760
  if self.enable_metrics:
1768
1761
  # only record queue time when enable_metrics is True to avoid overhead
1769
1762
  for req in can_run_list:
1770
- req.queue_time_end = time.perf_counter()
1763
+ req.add_latency(RequestStage.PREFILL_WAITING)
1771
1764
 
1772
1765
  self.waiting_queue = [
1773
1766
  x for x in self.waiting_queue if x not in set(can_run_list)
1774
1767
  ]
1768
+ if adder.preempt_list:
1769
+ for req in adder.preempt_list:
1770
+ self._add_request_to_queue(req)
1775
1771
 
1776
1772
  if adder.new_chunked_req is not None:
1777
1773
  assert self.chunked_req is None
@@ -1782,7 +1778,16 @@ class Scheduler(
1782
1778
 
1783
1779
  # Print stats
1784
1780
  if self.current_scheduler_metrics_enabled():
1785
- self.log_prefill_stats(adder, can_run_list, running_bs)
1781
+ self.log_prefill_stats(adder, can_run_list, running_bs, 0)
1782
+
1783
+ for req in can_run_list:
1784
+ if req.time_stats.forward_entry_time == 0:
1785
+ # Avoid updating the chunked request multiple times
1786
+ req.time_stats.forward_entry_time = time.perf_counter()
1787
+ if self.enable_metrics:
1788
+ self.metrics_collector.observe_queue_time(
1789
+ req.time_stats.get_queueing_time(),
1790
+ )
1786
1791
 
1787
1792
  # Create a new batch
1788
1793
  new_batch = ScheduleBatch.init_new(
@@ -1834,22 +1839,28 @@ class Scheduler(
1834
1839
 
1835
1840
  # Check if decode out of memory
1836
1841
  if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
1837
- TEST_RETRACT and batch.batch_size() > 10
1842
+ TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
1838
1843
  ):
1839
1844
  old_ratio = self.new_token_ratio
1840
-
1841
- retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
1842
- num_retracted_reqs = len(retracted_reqs)
1845
+ retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
1846
+ self.server_args
1847
+ )
1848
+ self.num_retracted_reqs = len(retracted_reqs)
1843
1849
  self.new_token_ratio = new_token_ratio
1850
+ for req in reqs_to_abort:
1851
+ self.send_to_tokenizer.send_output(
1852
+ AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
1853
+ )
1844
1854
 
1845
1855
  logger.info(
1846
1856
  "KV cache pool is full. Retract requests. "
1847
- f"#retracted_reqs: {num_retracted_reqs}, "
1848
- f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
1857
+ f"#retracted_reqs: {len(retracted_reqs)}, "
1858
+ f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
1859
+ f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
1849
1860
  )
1850
1861
 
1851
- self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
1852
- self.total_retracted_reqs += num_retracted_reqs
1862
+ for req in retracted_reqs:
1863
+ self._add_request_to_queue(req, is_retracted=True)
1853
1864
  else:
1854
1865
  self.new_token_ratio = max(
1855
1866
  self.new_token_ratio - self.new_token_ratio_decay,
@@ -1863,6 +1874,12 @@ class Scheduler(
1863
1874
  batch.prepare_for_decode()
1864
1875
  return batch
1865
1876
 
1877
+ # Placeholder for subclasses to override.
1878
+ def update_cache_from_scheduler(
1879
+ self, schedule_batch: ScheduleBatch, batch_result: GenerationBatchResult
1880
+ ):
1881
+ pass
1882
+
1866
1883
  def run_batch(
1867
1884
  self, batch: ScheduleBatch
1868
1885
  ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
@@ -1877,33 +1894,75 @@ class Scheduler(
1877
1894
 
1878
1895
  # Run forward
1879
1896
  if self.is_generation:
1880
- if self.spec_algorithm.is_none():
1881
- model_worker_batch = batch.get_model_worker_batch()
1882
1897
 
1883
- if self.pp_group.is_last_rank:
1884
- logits_output, next_token_ids, can_run_cuda_graph = (
1885
- self.tp_worker.forward_batch_generation(model_worker_batch)
1886
- )
1887
- else:
1888
- pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
1889
- self.tp_worker.forward_batch_generation(model_worker_batch)
1898
+ batch_or_worker_batch = batch
1899
+
1900
+ if self.enable_overlap or self.spec_algorithm.is_none():
1901
+ # FIXME(lsyin): remove this if and finally unify the abstraction
1902
+ batch_or_worker_batch = batch.get_model_worker_batch()
1903
+
1904
+ if self.enable_overlap:
1905
+ # FIXME: remove this assert
1906
+ assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
1907
+ model_worker_batch = batch_or_worker_batch
1908
+ self.record_batch_in_overlap(model_worker_batch)
1909
+
1910
+ # Sampling info will be modified during forward
1911
+ model_worker_batch.sampling_info = (
1912
+ model_worker_batch.sampling_info.copy_for_forward()
1913
+ )
1914
+
1915
+ bs = len(model_worker_batch.seq_lens)
1916
+ future_indices = self.future_map.alloc_future_indices(bs)
1917
+
1918
+ with self.forward_stream_ctx:
1919
+ self.forward_stream.wait_stream(self.default_stream)
1920
+ self.future_map.resolve_future(model_worker_batch)
1921
+ batch_result = self.model_worker.forward_batch_generation(
1922
+ model_worker_batch
1890
1923
  )
1891
- bid = model_worker_batch.bid
1924
+ # FIXME(lsyin): maybe move this to forward_batch_generation
1925
+ batch_result.copy_done = torch.get_device_module(
1926
+ self.device
1927
+ ).Event()
1928
+ if batch_result.delay_sample_func is None:
1929
+ self.future_map.store_to_map(future_indices, batch_result)
1930
+ batch_result.copy_to_cpu()
1931
+ else:
1932
+ batch_result.future_indices = future_indices
1933
+
1934
+ # FIXME(lsyin): move this assignment elsewhere
1935
+ future_indices_or_next_token_ids = -future_indices.indices
1936
+
1937
+ if batch.is_v2_eagle:
1938
+ # FIXME(lsyin): tmp code for eagle v2
1939
+ # We only keep future indices for next draft input
1940
+
1941
+ batch.spec_info = batch_result.next_draft_input
1942
+ batch.spec_info.future_indices = future_indices
1943
+
1944
+ # batch.spec_info = EagleDraftInput(
1945
+ # future_indices=future_indices,
1946
+ # verify_done=batch_result.next_draft_input.verify_done,
1947
+ # # FIXME(lsyin): remove the allocate_lens in EagleDraftInput
1948
+ # allocate_lens=batch_result.next_draft_input.allocate_lens,
1949
+ # )
1950
+
1951
+ # The future value, usually for next batch preparation
1952
+ # Current implementation strictly synchronizes the seq_lens
1953
+ batch.seq_lens = batch_result.next_draft_input.new_seq_lens
1892
1954
  else:
1893
- (
1894
- logits_output,
1895
- next_token_ids,
1896
- bid,
1897
- num_accepted_tokens,
1898
- can_run_cuda_graph,
1899
- ) = self.draft_worker.forward_batch_speculative_generation(batch)
1900
- bs = batch.batch_size()
1901
- self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
1902
- self.spec_num_total_forward_ct += bs
1903
- self.num_generated_tokens += num_accepted_tokens
1904
-
1905
- if self.pp_group.is_last_rank:
1906
- batch.output_ids = next_token_ids
1955
+ batch_result = self.model_worker.forward_batch_generation(
1956
+ batch_or_worker_batch
1957
+ )
1958
+ future_indices_or_next_token_ids = batch_result.next_token_ids
1959
+ self.update_cache_from_scheduler(batch, batch_result)
1960
+
1961
+ # NOTE: future_indices_or_next_token_ids is used in ScheduleBatch,
1962
+ # which can probably be replaced by future_indices later [TODO(lsyin)].
1963
+ # we shall still keep the original outputs, e.g. next_token_ids
1964
+ # in the GenerationBatchResult for processing after copy_done.
1965
+ batch.output_ids = future_indices_or_next_token_ids
1907
1966
 
1908
1967
  # These 2 values are needed for processing the output, but the values can be
1909
1968
  # modified by overlap schedule. So we have to copy them here so that
@@ -1912,6 +1971,7 @@ class Scheduler(
1912
1971
  extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
1913
1972
  else:
1914
1973
  extend_input_len_per_req = None
1974
+
1915
1975
  if batch.return_logprob:
1916
1976
  extend_logprob_start_len_per_req = [
1917
1977
  req.extend_logprob_start_len for req in batch.reqs
@@ -1919,58 +1979,51 @@ class Scheduler(
1919
1979
  else:
1920
1980
  extend_logprob_start_len_per_req = None
1921
1981
 
1922
- ret = GenerationBatchResult(
1923
- logits_output=logits_output if self.pp_group.is_last_rank else None,
1924
- pp_hidden_states_proxy_tensors=(
1925
- pp_hidden_states_proxy_tensors
1926
- if not self.pp_group.is_last_rank
1927
- else None
1928
- ),
1929
- next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
1930
- extend_input_len_per_req=extend_input_len_per_req,
1931
- extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
1932
- bid=bid,
1933
- can_run_cuda_graph=can_run_cuda_graph,
1982
+ batch_result.extend_input_len_per_req = extend_input_len_per_req
1983
+ batch_result.extend_logprob_start_len_per_req = (
1984
+ extend_logprob_start_len_per_req
1934
1985
  )
1986
+ return batch_result
1935
1987
  else: # embedding or reward model
1936
1988
  model_worker_batch = batch.get_model_worker_batch()
1937
1989
  embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
1938
- ret = EmbeddingBatchResult(
1939
- embeddings=embeddings, bid=model_worker_batch.bid
1940
- )
1990
+ ret = EmbeddingBatchResult(embeddings=embeddings)
1941
1991
  return ret
1942
1992
 
1993
+ def launch_batch_sample_if_needed(
1994
+ self, batch_result: GenerationBatchResult
1995
+ ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
1996
+ # TODO(lsyin): make the delayed sample a default behavior after
1997
+ # unifying the forward_batch_generation interface (related to spec V2).
1998
+ if batch_result is None or batch_result.delay_sample_func is None:
1999
+ return
2000
+
2001
+ with self.forward_stream_ctx:
2002
+ self.forward_stream.wait_stream(self.default_stream)
2003
+ _batch_result = batch_result.delay_sample_func()
2004
+ assert _batch_result is batch_result
2005
+ self.future_map.store_to_map(batch_result.future_indices, batch_result)
2006
+ batch_result.copy_to_cpu()
2007
+
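The negative values stored into batch.output_ids by run_batch act as placeholders: real token ids are non-negative, so a negative entry means "read this slot from the future map once sampling has produced it". A rough sketch of that resolution step, assuming the negative-index encoding shown above (the dict-based future map is illustrative, not the actual FutureMap class):

# slot index -> token id, filled in after sampling completes on the GPU stream
future_values = {1: 101, 2: 102, 3: 103}


def resolve(output_ids):
    resolved = []
    for v in output_ids:
        if v < 0:
            # Negative entries are future-map slots written by run_batch.
            resolved.append(future_values[-v])
        else:
            # Non-negative entries are already concrete token ids.
            resolved.append(v)
    return resolved


print(resolve([-1, -2, 42, -3]))  # [101, 102, 42, 103]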
1943
2008
  def process_batch_result(
1944
2009
  self,
1945
2010
  batch: ScheduleBatch,
1946
2011
  result: Union[GenerationBatchResult, EmbeddingBatchResult],
1947
- launch_done: Optional[threading.Event] = None,
1948
2012
  ):
1949
2013
  if batch.forward_mode.is_decode():
1950
- self.process_batch_result_decode(batch, result, launch_done)
1951
- for req in batch.reqs:
1952
- trace_slice(
1953
- "decode loop",
1954
- req.rid,
1955
- auto_next_anon=not req.finished(),
1956
- thread_finish_flag=req.finished(),
1957
- )
2014
+ self.process_batch_result_decode(batch, result)
2015
+ if self.enable_trace:
2016
+ trace_slice_batch("decode loop", batch.reqs)
1958
2017
 
1959
2018
  elif batch.forward_mode.is_extend():
1960
- self.process_batch_result_prefill(batch, result, launch_done)
1961
- for req in batch.reqs:
1962
- trace_slice(
1963
- "prefill",
1964
- req.rid,
1965
- auto_next_anon=not req.finished(),
1966
- thread_finish_flag=req.finished(),
1967
- )
2019
+ self.process_batch_result_prefill(batch, result)
2020
+ if self.enable_trace:
2021
+ trace_slice_batch("prefill", batch.reqs)
2022
+
1968
2023
  elif batch.forward_mode.is_idle():
1969
2024
  if self.enable_overlap:
1970
- self.tp_worker.resolve_last_batch_result(launch_done)
1971
- self.set_next_batch_sampling_info_done(batch)
1972
- elif batch.forward_mode.is_dummy_first():
1973
- self.set_next_batch_sampling_info_done(batch)
2025
+ if result.copy_done is not None:
2026
+ result.copy_done.synchronize()
1974
2027
 
1975
2028
  self.maybe_send_health_check_signal()
1976
2029
 
@@ -1980,7 +2033,7 @@ class Scheduler(
1980
2033
  # This is used to prevent the health check signal being blocked by long context prefill.
1981
2034
  # However, one minor issue is that this code path does not check the status of detokenizer manager.
1982
2035
  self.return_health_check_ct -= 1
1983
- self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
2036
+ self.send_to_tokenizer.send_output(HealthCheckOutput())
1984
2037
 
1985
2038
  def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
1986
2039
  return self.prepare_mlp_sync_batch_raw(
@@ -1994,6 +2047,7 @@ class Scheduler(
1994
2047
  speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
1995
2048
  require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
1996
2049
  disable_overlap_schedule=self.server_args.disable_overlap_schedule,
2050
+ offload_tags=self.offload_tags,
1997
2051
  )
1998
2052
 
1999
2053
  @staticmethod
@@ -2008,6 +2062,7 @@ class Scheduler(
2008
2062
  speculative_num_draft_tokens,
2009
2063
  require_mlp_tp_gather: bool,
2010
2064
  disable_overlap_schedule: bool,
2065
+ offload_tags: set[str],
2011
2066
  ):
2012
2067
  # Check if other DP workers have running batches
2013
2068
  if local_batch is None:
@@ -2038,7 +2093,7 @@ class Scheduler(
2038
2093
  )
2039
2094
 
2040
2095
  tbo_preparer = TboDPAttentionPreparer()
2041
- if disable_overlap_schedule:
2096
+ if len(offload_tags) == 0 and disable_overlap_schedule:
2042
2097
  group = tp_group.device_group
2043
2098
  device = tp_group.device
2044
2099
  else:
@@ -2123,12 +2178,13 @@ class Scheduler(
2123
2178
  if req.finished(): # It is aborted by AbortReq
2124
2179
  num_ready_reqs += 1
2125
2180
  continue
2181
+
2126
2182
  req.grammar = req.grammar.result(timeout=0.03)
2127
2183
  self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
2128
2184
  if req.grammar is INVALID_GRAMMAR_OBJ:
2129
- req.set_finish_with_abort(
2130
- f"Invalid grammar request: {req.grammar_key=}"
2131
- )
2185
+ error_msg = f"Invalid grammar request: {req.grammar_key=}"
2186
+ req.set_finish_with_abort(error_msg)
2187
+
2132
2188
  num_ready_reqs += 1
2133
2189
  except futures._base.TimeoutError:
2134
2190
  req.grammar_wait_ct += 1
@@ -2160,9 +2216,8 @@ class Scheduler(
2160
2216
  req.grammar = req.grammar.result()
2161
2217
  self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
2162
2218
  if req.grammar is INVALID_GRAMMAR_OBJ:
2163
- req.set_finish_with_abort(
2164
- f"Invalid grammar request: {req.grammar_key=}"
2165
- )
2219
+ error_msg = f"Invalid grammar request: {req.grammar_key=}"
2220
+ req.set_finish_with_abort(error_msg)
2166
2221
  else:
2167
2222
  num_ready_reqs_max = num_ready_reqs
2168
2223
  num_timeout_reqs_max = num_timeout_reqs
@@ -2170,21 +2225,16 @@ class Scheduler(
2170
2225
  for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
2171
2226
  req = self.grammar_queue[i]
2172
2227
  req.grammar.cancel()
2228
+ self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
2173
2229
  error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
2174
2230
  req.set_finish_with_abort(error_msg)
2175
- self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
2231
+
2176
2232
  num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
2177
2233
 
2178
- self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
2234
+ for req in self.grammar_queue[:num_ready_reqs]:
2235
+ self._add_request_to_queue(req)
2179
2236
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]
2180
2237
 
2181
- def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
2182
- if batch.next_batch_sampling_info:
2183
- if batch.next_batch_sampling_info.grammars is not None:
2184
- batch.next_batch_sampling_info.update_regex_vocab_mask()
2185
- self.current_stream.synchronize()
2186
- batch.next_batch_sampling_info.sampling_info_done.set()
2187
-
2188
2238
  def watchdog_thread(self):
2189
2239
  """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
2190
2240
  self.watchdog_last_forward_ct = 0
@@ -2267,9 +2317,8 @@ class Scheduler(
2267
2317
  self.req_to_token_pool.clear()
2268
2318
  self.token_to_kv_pool_allocator.clear()
2269
2319
 
2270
- if not self.spec_algorithm.is_none():
2271
- self.draft_worker.model_runner.req_to_token_pool.clear()
2272
- self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
2320
+ if self.draft_worker:
2321
+ self.draft_worker.clear_cache_pool()
2273
2322
 
2274
2323
  self.num_generated_tokens = 0
2275
2324
  self.forward_ct_decode = 0
@@ -2335,12 +2384,10 @@ class Scheduler(
2335
2384
  )
2336
2385
 
2337
2386
  def get_internal_state(self, recv_req: GetInternalStateReq):
2338
- ret = dict(global_server_args_dict)
2387
+ ret = vars(get_global_server_args())
2339
2388
  ret["last_gen_throughput"] = self.last_gen_throughput
2340
2389
  ret["memory_usage"] = {
2341
- "weight": round(
2342
- self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
2343
- ),
2390
+ "weight": round(self.tp_worker.model_runner.weight_load_mem_usage, 2),
2344
2391
  "kvcache": round(
2345
2392
  self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
2346
2393
  ),
@@ -2348,7 +2395,7 @@ class Scheduler(
2348
2395
  }
2349
2396
 
2350
2397
  ret["memory_usage"]["graph"] = round(
2351
- self.tp_worker.worker.model_runner.graph_mem_usage, 2
2398
+ self.tp_worker.model_runner.graph_mem_usage, 2
2352
2399
  )
2353
2400
 
2354
2401
  if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
@@ -2364,7 +2411,7 @@ class Scheduler(
2364
2411
  server_args_dict = recv_req.server_args
2365
2412
  args_allow_update = set(
2366
2413
  [
2367
- "max_micro_batch_size",
2414
+ "pp_max_micro_batch_size",
2368
2415
  "speculative_accept_threshold_single",
2369
2416
  "speculative_accept_threshold_acc",
2370
2417
  ]
@@ -2375,7 +2422,7 @@ class Scheduler(
2375
2422
  logging.warning(f"Updating {k} is not supported.")
2376
2423
  if_success = False
2377
2424
  break
2378
- elif k == "max_micro_batch_size" and (
2425
+ elif k == "pp_max_micro_batch_size" and (
2379
2426
  v > self.max_running_requests // self.pp_size or v < 1
2380
2427
  ):
2381
2428
  logging.warning(
@@ -2391,11 +2438,11 @@ class Scheduler(
2391
2438
  logger.info(f"{avg_spec_accept_length=}")
2392
2439
  self.cum_spec_accept_length = self.cum_spec_accept_count = 0
2393
2440
  for k, v in server_args_dict.items():
2394
- global_server_args_dict[k] = v
2395
- logger.info(f"Global server args updated! {global_server_args_dict=}")
2441
+ setattr(get_global_server_args(), k, v)
2442
+ logger.info(f"Global server args updated! {get_global_server_args()=}")
2396
2443
  return SetInternalStateReqOutput(
2397
2444
  updated=True,
2398
- server_args=global_server_args_dict,
2445
+ server_args=vars(get_global_server_args()),
2399
2446
  )
2400
2447
 
2401
2448
  def handle_rpc_request(self, recv_req: RpcReqInput):
@@ -2433,7 +2480,7 @@ class Scheduler(
2433
2480
  if self.enable_hicache_storage:
2434
2481
  # to release prefetch events associated with the request
2435
2482
  self.tree_cache.release_aborted_request(req.rid)
2436
- self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
2483
+ self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req)
2437
2484
  # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
2438
2485
  if self.disaggregation_mode == DisaggregationMode.DECODE:
2439
2486
  self.tree_cache.cache_finished_req(req)
@@ -2454,31 +2501,31 @@ class Scheduler(
2454
2501
  # Delete requests not in the waiting queue when PD disaggregation is enabled
2455
2502
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
2456
2503
  # Abort requests that have not yet been bootstrapped
2457
- for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
2458
- logger.debug(f"Abort bootstrap queue request. {req.rid=}")
2504
+ for req in self.disagg_prefill_bootstrap_queue.queue:
2459
2505
  if recv_req.abort_all or req.rid.startswith(recv_req.rid):
2506
+ logger.debug(f"Abort bootstrap queue request. {req.rid=}")
2460
2507
  if hasattr(req.disagg_kv_sender, "abort"):
2461
2508
  req.disagg_kv_sender.abort()
2462
2509
 
2463
2510
  # Abort in-flight requests
2464
- for i, req in enumerate(self.disagg_prefill_inflight_queue):
2465
- logger.debug(f"Abort inflight queue request. {req.rid=}")
2511
+ for req in self.disagg_prefill_inflight_queue:
2466
2512
  if recv_req.abort_all or req.rid.startswith(recv_req.rid):
2513
+ logger.debug(f"Abort inflight queue request. {req.rid=}")
2467
2514
  if hasattr(req.disagg_kv_sender, "abort"):
2468
2515
  req.disagg_kv_sender.abort()
2469
2516
 
2470
2517
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
2471
2518
  # Abort requests that have not yet finished preallocation
2472
- for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
2473
- logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
2519
+ for decode_req in self.disagg_decode_prealloc_queue.queue:
2474
2520
  if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
2521
+ logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
2475
2522
  if hasattr(decode_req.kv_receiver, "abort"):
2476
2523
  decode_req.kv_receiver.abort()
2477
2524
 
2478
2525
  # Abort requests waiting for kvcache to release tree cache
2479
- for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
2480
- logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
2526
+ for decode_req in self.disagg_decode_transfer_queue.queue:
2481
2527
  if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
2528
+ logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
2482
2529
  if hasattr(decode_req.kv_receiver, "abort"):
2483
2530
  decode_req.kv_receiver.abort()
2484
2531
 
@@ -2517,10 +2564,6 @@ class Scheduler(
2517
2564
  result = self.tp_worker.unload_lora_adapter(recv_req)
2518
2565
  return result
2519
2566
 
2520
- def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
2521
- self.send_to_detokenizer.send_pyobj(recv_req)
2522
- return recv_req
2523
-
2524
2567
  def init_weights_send_group_for_remote_instance(
2525
2568
  self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
2526
2569
  ):
@@ -2545,11 +2588,12 @@ class Scheduler(
2545
2588
  return SlowDownReqOutput()
2546
2589
 
2547
2590
  def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
2548
- if recv_req == ExpertDistributionReq.START_RECORD:
2591
+ action = recv_req.action
2592
+ if action == ExpertDistributionReqType.START_RECORD:
2549
2593
  get_global_expert_distribution_recorder().start_record()
2550
- elif recv_req == ExpertDistributionReq.STOP_RECORD:
2594
+ elif action == ExpertDistributionReqType.STOP_RECORD:
2551
2595
  get_global_expert_distribution_recorder().stop_record()
2552
- elif recv_req == ExpertDistributionReq.DUMP_RECORD:
2596
+ elif action == ExpertDistributionReqType.DUMP_RECORD:
2553
2597
  get_global_expert_distribution_recorder().dump_record()
2554
2598
  else:
2555
2599
  raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
@@ -2598,7 +2642,7 @@ class Scheduler(
2598
2642
  def handle_freeze_gc(self, recv_req: FreezeGCReq):
2599
2643
  """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
2600
2644
  freeze_gc("Scheduler")
2601
- self.send_to_detokenizer.send_pyobj(recv_req)
2645
+ self.send_to_detokenizer.send_output(recv_req, recv_req)
2602
2646
  return None
2603
2647
 
2604
2648
 
@@ -2620,19 +2664,21 @@ class IdleSleeper:
2620
2664
  for s in sockets:
2621
2665
  self.poller.register(s, zmq.POLLIN)
2622
2666
 
2667
+ self.empty_cache_interval = envs.SGLANG_EMPTY_CACHE_INTERVAL.get()
2668
+
2623
2669
  def maybe_sleep(self):
2624
2670
  self.poller.poll(1000)
2625
2671
  if (
2626
- global_config.torch_empty_cache_interval > 0
2627
- and time.time() - self.last_empty_time
2628
- > global_config.torch_empty_cache_interval
2672
+ self.empty_cache_interval > 0
2673
+ and time.time() - self.last_empty_time > self.empty_cache_interval
2629
2674
  ):
2630
2675
  self.last_empty_time = time.time()
2631
2676
  torch.cuda.empty_cache()
2632
2677
 
2633
2678
 
2634
2679
  def is_health_check_generate_req(recv_req):
2635
- return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
2680
+ rid = getattr(recv_req, "rid", None)
2681
+ return rid is not None and rid.startswith("HEALTH_CHECK")
2636
2682
 
2637
2683
 
2638
2684
  def is_work_request(recv_req):
@@ -2656,19 +2702,12 @@ def run_scheduler_process(
2656
2702
  pp_rank: int,
2657
2703
  dp_rank: Optional[int],
2658
2704
  pipe_writer,
2659
- balance_meta: Optional[DPBalanceMeta] = None,
2660
2705
  ):
2661
- if server_args.enable_trace:
2662
- process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
2663
- if server_args.disaggregation_mode == "null":
2664
- thread_label = "Scheduler"
2665
- trace_set_thread_info(thread_label, tp_rank, dp_rank)
2666
-
2667
- if (numa_node := server_args.numa_node) is not None:
2668
- numa_bind_to_node(numa_node[gpu_id])
2669
-
2670
- # Generate the prefix
2706
+ # Generate the logger prefix
2671
2707
  prefix = ""
2708
+ if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
2709
+ # [For Router] if the env var "SGLANG_DP_RANK" exists, set dp_rank to its value
2710
+ dp_rank = int(os.environ["SGLANG_DP_RANK"])
2672
2711
  if dp_rank is not None:
2673
2712
  prefix += f" DP{dp_rank}"
2674
2713
  if server_args.tp_size > 1:
@@ -2684,17 +2723,24 @@ def run_scheduler_process(
2684
2723
  kill_itself_when_parent_died()
2685
2724
  parent_process = psutil.Process().parent()
2686
2725
 
2687
- # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
2688
- if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
2689
- dp_rank = int(os.environ["SGLANG_DP_RANK"])
2690
-
2691
2726
  # Configure the logger
2692
2727
  configure_logger(server_args, prefix=prefix)
2693
2728
  suppress_other_loggers()
2694
2729
 
2695
2730
  # Set cpu affinity to this gpu process
2696
2731
  if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
2697
- set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
2732
+ set_gpu_proc_affinity(
2733
+ server_args.pp_size, server_args.tp_size, server_args.nnodes, gpu_id
2734
+ )
2735
+ if (numa_node := server_args.numa_node) is not None:
2736
+ numa_bind_to_node(numa_node[gpu_id])
2737
+
2738
+ # Set up tracing
2739
+ if server_args.enable_trace:
2740
+ process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
2741
+ if server_args.disaggregation_mode == "null":
2742
+ thread_label = "Scheduler"
2743
+ trace_set_thread_info(thread_label, tp_rank, dp_rank)
2698
2744
 
2699
2745
  # Create a scheduler and run the event loop
2700
2746
  try:
@@ -2706,7 +2752,6 @@ def run_scheduler_process(
2706
2752
  moe_ep_rank,
2707
2753
  pp_rank,
2708
2754
  dp_rank,
2709
- dp_balance_meta=balance_meta,
2710
2755
  )
2711
2756
  pipe_writer.send(
2712
2757
  {