sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
  DecodeTransferQueue,
  SchedulerDisaggregationDecodeMixin,
  )
+ from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
+ DecodeKVCacheOffloadManager,
+ )
  from sglang.srt.disaggregation.prefill import (
  PrefillBootstrapQueue,
  SchedulerDisaggregationPrefillMixin,
@@ -57,11 +60,6 @@ from sglang.srt.disaggregation.utils import (
  )
  from sglang.srt.distributed import get_pp_group, get_world_group
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
- from sglang.srt.hf_transformers_utils import (
- get_processor,
- get_tokenizer,
- get_tokenizer_from_processor,
- )
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.moe import initialize_moe_config
@@ -72,8 +70,10 @@ from sglang.srt.managers.io_struct import (
  ClearHiCacheReqInput,
  ClearHiCacheReqOutput,
  CloseSessionReqInput,
+ DestroyWeightsUpdateGroupReqInput,
  ExpertDistributionReq,
  ExpertDistributionReqOutput,
+ ExpertDistributionReqType,
  FlushCacheReqInput,
  FlushCacheReqOutput,
  FreezeGCReq,
@@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
  FINISH_ABORT,
  MultimodalInputs,
  Req,
+ RequestStage,
  ScheduleBatch,
  global_server_args_dict,
  )
@@ -140,23 +141,25 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
  from sglang.srt.managers.session_controller import Session
  from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
- from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
+ from sglang.srt.managers.utils import validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
- from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
- from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
+ from sglang.srt.model_executor.forward_batch_info import (
+ ForwardBatchOutput,
+ ForwardMode,
+ PPProxyTensors,
+ )
  from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.tracing.trace import (
  process_tracing_init,
- trace_event,
  trace_set_proc_propagate_context,
  trace_set_thread_info,
- trace_slice,
+ trace_slice_batch,
  trace_slice_end,
  trace_slice_start,
  )
@@ -170,8 +173,8 @@ from sglang.srt.utils import (
  freeze_gc,
  get_available_gpu_memory,
  get_bool_env_var,
+ get_int_env_var,
  get_zmq_socket,
- is_cpu,
  kill_itself_when_parent_died,
  numa_bind_to_node,
  point_to_point_pyobj,
@@ -182,6 +185,11 @@ from sglang.srt.utils import (
  set_random_seed,
  suppress_other_loggers,
  )
+ from sglang.srt.utils.hf_transformers_utils import (
+ get_processor,
+ get_tokenizer,
+ get_tokenizer_from_processor,
+ )
  from sglang.utils import TypeBasedDispatcher, get_exception_traceback

  logger = logging.getLogger(__name__)
@@ -190,24 +198,59 @@ logger = logging.getLogger(__name__)
  TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
  GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

- _is_cpu = is_cpu()
-

  @dataclass
  class GenerationBatchResult:
  logits_output: Optional[LogitsProcessorOutput]
- pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+ pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
  next_token_ids: Optional[List[int]]
+ can_run_cuda_graph: bool
+
+ # For output processing
  extend_input_len_per_req: List[int]
  extend_logprob_start_len_per_req: List[int]
- bid: int
- can_run_cuda_graph: bool
+
+ @classmethod
+ def from_forward_batch_output(
+ cls,
+ forward_batch_output: ForwardBatchOutput,
+ extend_input_len_per_req: List[int],
+ extend_logprob_start_len_per_req: List[int],
+ ):
+ # TODO(lsyin): remove this workaround logic and try to unify output classes
+
+ return cls(
+ logits_output=forward_batch_output.logits_output,
+ pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
+ next_token_ids=forward_batch_output.next_token_ids,
+ extend_input_len_per_req=extend_input_len_per_req,
+ extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+ can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
+ )
+
+ @classmethod
+ def from_pp_proxy(
+ cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+ ):
+ # TODO(lsyin): also simplify this logic
+ # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
+ # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+ proxy_dict = next_pp_outputs.tensors
+ return cls(
+ logits_output=logits_output,
+ pp_hidden_states_proxy_tensors=None,
+ next_token_ids=next_pp_outputs["next_token_ids"],
+ extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+ extend_logprob_start_len_per_req=proxy_dict.get(
+ "extend_logprob_start_len_per_req", None
+ ),
+ can_run_cuda_graph=can_run_cuda_graph,
+ )


  @dataclass
  class EmbeddingBatchResult:
  embeddings: torch.Tensor
- bid: int


  class Scheduler(
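The hunk above drops the bid field and replaces per-field construction of GenerationBatchResult with two classmethod constructors. Below is a minimal, self-contained sketch (illustrative, not part of the package) of the intended call shape, using stand-in types rather than the real sglang classes:

    from dataclasses import dataclass
    from typing import Any, List, Optional

    # Stand-in for ForwardBatchOutput, reduced to the fields the new
    # from_forward_batch_output() constructor reads.
    @dataclass
    class ForwardBatchOutput:
        logits_output: Optional[Any] = None
        pp_proxy_tensors: Optional[Any] = None
        next_token_ids: Optional[List[int]] = None
        can_run_cuda_graph: bool = False

    @dataclass
    class GenerationBatchResult:
        logits_output: Optional[Any]
        pp_hidden_states_proxy_tensors: Optional[Any]
        next_token_ids: Optional[List[int]]
        can_run_cuda_graph: bool
        extend_input_len_per_req: Optional[List[int]]
        extend_logprob_start_len_per_req: Optional[List[int]]

        @classmethod
        def from_forward_batch_output(
            cls, forward_batch_output, extend_input_len_per_req, extend_logprob_start_len_per_req
        ):
            # Mirrors the classmethod added in the hunk above.
            return cls(
                logits_output=forward_batch_output.logits_output,
                pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
                next_token_ids=forward_batch_output.next_token_ids,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
            )

    # The scheduler's run_batch() path now builds the result in a single call:
    out = ForwardBatchOutput(next_token_ids=[42], can_run_cuda_graph=True)
    result = GenerationBatchResult.from_forward_batch_output(
        forward_batch_output=out,
        extend_input_len_per_req=[8],
        extend_logprob_start_len_per_req=[0],
    )
    print(result.next_token_ids, result.can_run_cuda_graph)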
@@ -229,7 +272,6 @@ class Scheduler(
  moe_ep_rank: int,
  pp_rank: int,
  dp_rank: Optional[int],
- dp_balance_meta: Optional[DPBalanceMeta] = None,
  ):
  # Parse args
  self.server_args = server_args
@@ -242,6 +284,13 @@ class Scheduler(
  self.pp_size = server_args.pp_size
  self.dp_size = server_args.dp_size
  self.schedule_policy = server_args.schedule_policy
+ self.enable_priority_scheduling = server_args.enable_priority_scheduling
+ self.schedule_low_priority_values_first = (
+ server_args.schedule_low_priority_values_first
+ )
+ self.priority_scheduling_preemption_threshold = (
+ server_args.priority_scheduling_preemption_threshold
+ )
  self.enable_lora = server_args.enable_lora
  self.max_loras_per_batch = server_args.max_loras_per_batch
  self.enable_overlap = not server_args.disable_overlap_schedule
@@ -250,7 +299,10 @@ class Scheduler(
  self.enable_metrics_for_all_schedulers = (
  server_args.enable_metrics_for_all_schedulers
  )
- self.enable_kv_cache_events = server_args.kv_events_config is not None
+ self.enable_kv_cache_events = bool(
+ server_args.kv_events_config and tp_rank == 0
+ )
+ self.enable_trace = server_args.enable_trace
  self.stream_interval = server_args.stream_interval
  self.spec_algorithm = SpeculativeAlgorithm.from_string(
  server_args.speculative_algorithm
@@ -376,9 +428,27 @@ class Scheduler(
  target_worker=self.tp_worker,
  dp_rank=dp_rank,
  )
+ elif self.spec_algorithm.is_ngram():
+ from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+ self.draft_worker = NGRAMWorker(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
  else:
  self.draft_worker = None

+ # Dispatch the model worker
+ if self.spec_algorithm.is_none():
+ self.model_worker = self.tp_worker
+ else:
+ self.model_worker = self.draft_worker
+
  # Get token and memory info from the model worker
  (
  self.max_total_num_tokens,
@@ -486,7 +556,12 @@ class Scheduler(
  self.schedule_policy,
  self.tree_cache,
  self.enable_hierarchical_cache,
+ self.enable_priority_scheduling,
+ self.schedule_low_priority_values_first,
  )
+ # Enable preemption for priority scheduling.
+ self.try_preemption = self.enable_priority_scheduling
+
  assert (
  server_args.schedule_conservativeness >= 0
  ), "Invalid schedule_conservativeness"
@@ -527,8 +602,9 @@ class Scheduler(

  # Init metrics stats
  self.init_metrics(tp_rank, pp_rank, dp_rank)
- self.init_kv_events(server_args.kv_events_config)
- self.init_dp_balance(dp_balance_meta)
+
+ if self.enable_kv_cache_events:
+ self.init_kv_events(server_args.kv_events_config)

  # Init disaggregation
  self.disaggregation_mode = DisaggregationMode(
@@ -539,6 +615,9 @@ class Scheduler(
  if get_bool_env_var("SGLANG_GC_LOG"):
  configure_gc_logger()

+ # Init prefill kv split size when deterministic inference is enabled with various attention backends
+ self.init_deterministic_inference_config()
+
  # Init request dispatcher
  self._request_dispatcher = TypeBasedDispatcher(
  [
@@ -553,6 +632,7 @@ class Scheduler(
  (CloseSessionReqInput, self.close_session),
  (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
  (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+ (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
  (
  InitWeightsSendGroupForRemoteInstanceReqInput,
  self.init_weights_send_group_for_remote_instance,
@@ -583,6 +663,23 @@ class Scheduler(
  ]
  )

+ def init_deterministic_inference_config(self):
+ """Initialize deterministic inference configuration for different attention backends."""
+ if not self.server_args.enable_deterministic_inference:
+ self.truncation_align_size = None
+ return
+
+ backend_sizes = {
+ "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
+ "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
+ }
+ env_var, default_size = backend_sizes.get(
+ self.server_args.attention_backend, (None, None)
+ )
+ self.truncation_align_size = (
+ get_int_env_var(env_var, default_size) if env_var else None
+ )
+
  def init_tokenizer(self):
  server_args = self.server_args
  self.is_generation = self.model_config.is_generation
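init_deterministic_inference_config() above resolves a per-backend prefill split/truncation size from an environment variable, defaulting to 4096. A standalone sketch (illustrative, not part of the package) of that lookup; get_int_env_var is approximated here, and only flashinfer and triton have table entries:

    import os

    def get_int_env_var(name: str, default: int) -> int:
        # Approximation of the helper imported from sglang.srt.utils: fall back
        # to the default when the variable is unset or not an integer.
        value = os.environ.get(name, "")
        return int(value) if value.isdigit() else default

    # (env var, default) pairs copied from the hunk above.
    BACKEND_SIZES = {
        "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
        "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
    }

    def resolve_truncation_align_size(attention_backend: str, deterministic: bool):
        if not deterministic:
            return None
        env_var, default_size = BACKEND_SIZES.get(attention_backend, (None, None))
        return get_int_env_var(env_var, default_size) if env_var else None

    # Example: override the flashinfer split size, leave other backends untouched.
    os.environ["SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE"] = "2048"
    print(resolve_truncation_align_size("flashinfer", deterministic=True))  # 2048
    print(resolve_truncation_align_size("fa3", deterministic=True))         # None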
@@ -654,6 +751,7 @@ class Scheduler(
  else self.tp_cpu_group
  ),
  page_size=self.page_size,
+ eviction_policy=server_args.radix_eviction_policy,
  hicache_ratio=server_args.hicache_ratio,
  hicache_size=server_args.hicache_size,
  hicache_write_policy=server_args.hicache_write_policy,
@@ -664,6 +762,7 @@ class Scheduler(
  hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
  model_name=server_args.served_model_name,
  storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
+ is_eagle=self.spec_algorithm.is_eagle(),
  )
  self.tp_worker.register_hicache_layer_transfer_counter(
  self.tree_cache.cache_controller.layer_done_counter
@@ -679,19 +778,6 @@ class Scheduler(
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
  )
- elif self.enable_lora:
- assert (
- not self.enable_hierarchical_cache
- ), "LoRA radix cache doesn't support hierarchical cache"
- assert (
- self.schedule_policy == "fcfs"
- ), "LoRA radix cache only supports FCFS policy"
- self.tree_cache = LoRARadixCache(
- req_to_token_pool=self.req_to_token_pool,
- token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
- page_size=self.page_size,
- disable=server_args.disable_radix_cache,
- )
  elif server_args.enable_lmcache:
  from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
  LMCRadixCache,
@@ -706,6 +792,7 @@ class Scheduler(
  tp_size=self.tp_size,
  rank=self.tp_rank,
  tp_group=self.tp_group,
+ eviction_policy=server_args.radix_eviction_policy,
  )
  else:
  self.tree_cache = RadixCache(
@@ -714,16 +801,36 @@ class Scheduler(
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
  enable_kv_cache_events=self.enable_kv_cache_events,
+ eviction_policy=server_args.radix_eviction_policy,
+ is_eagle=self.spec_algorithm.is_eagle(),
  )

+ if (
+ server_args.disaggregation_mode == "decode"
+ and server_args.disaggregation_decode_enable_offload_kvcache
+ ):
+ self.decode_offload_manager = DecodeKVCacheOffloadManager(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ tp_group=(
+ self.attn_tp_cpu_group
+ if self.server_args.enable_dp_attention
+ else self.tp_cpu_group
+ ),
+ tree_cache=self.tree_cache,
+ server_args=self.server_args,
+ )
+ else:
+ self.decode_offload_manager = None
+
  self.decode_mem_cache_buf_multiplier = (
  1
  if self.spec_algorithm.is_none()
  else (
  server_args.speculative_num_draft_tokens
  + (
- server_args.speculative_eagle_topk
- * server_args.speculative_num_steps
+ (server_args.speculative_eagle_topk or 1)
+ * (server_args.speculative_num_steps or 1)
  )
  )
  )
@@ -746,7 +853,7 @@ class Scheduler(
  self.disagg_metadata_buffers = MetadataBuffers(
  buffer_size,
  hidden_size=self.model_config.hf_text_config.hidden_size,
- dtype=self.model_config.dtype,
+ hidden_states_dtype=self.model_config.dtype,
  custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
  )

@@ -766,7 +873,7 @@ class Scheduler(
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
  draft_token_to_kv_pool=(
  None
- if self.draft_worker is None
+ if self.draft_worker is None or self.spec_algorithm.is_ngram()
  else self.draft_worker.model_runner.token_to_kv_pool
  ),
  req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -795,7 +902,7 @@ class Scheduler(
  self.disagg_metadata_buffers = MetadataBuffers(
  buffer_size,
  hidden_size=self.model_config.hf_text_config.hidden_size,
- dtype=self.model_config.dtype,
+ hidden_states_dtype=self.model_config.dtype,
  custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
  )

@@ -803,7 +910,7 @@ class Scheduler(
  token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
  draft_token_to_kv_pool=(
  None
- if self.draft_worker is None
+ if self.draft_worker is None or self.spec_algorithm.is_ngram()
  else self.draft_worker.model_runner.token_to_kv_pool
  ),
  req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -838,10 +945,6 @@ class Scheduler(
  batch = self.get_next_batch_to_run()
  self.cur_batch = batch

- if batch:
- for req in batch.reqs:
- trace_event("schedule", req.rid)
-
  if batch:
  result = self.run_batch(batch)
  self.process_batch_result(batch, result)
@@ -863,10 +966,6 @@ class Scheduler(
  batch = self.get_next_batch_to_run()
  self.cur_batch = batch

- if batch:
- for req in batch.reqs:
- trace_event("schedule", req.rid)
-
  if batch:
  batch.launch_done = threading.Event()
  result = self.run_batch(batch)
@@ -906,7 +1005,6 @@ class Scheduler(
  self.running_mbs = [
  ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
  ]
- bids = [None] * self.pp_size
  pp_outputs: Optional[PPProxyTensors] = None
  while True:
  server_is_idle = True
@@ -927,10 +1025,7 @@ class Scheduler(
  # (last rank) send the outputs to the next step
  if self.pp_group.is_last_rank:
  if self.cur_batch:
- next_token_ids, bids[mb_id] = (
- result.next_token_ids,
- result.bid,
- )
+ next_token_ids = result.next_token_ids
  if self.cur_batch.return_logprob:
  pp_outputs = PPProxyTensors(
  {
@@ -978,17 +1073,10 @@ class Scheduler(
  logits_output = LogitsProcessorOutput(**logits_output_args)
  else:
  logits_output = None
- output_result = GenerationBatchResult(
+
+ output_result = GenerationBatchResult.from_pp_proxy(
  logits_output=logits_output,
- pp_hidden_states_proxy_tensors=None,
- next_token_ids=next_pp_outputs["next_token_ids"],
- extend_input_len_per_req=next_pp_outputs.tensors.get(
- "extend_input_len_per_req", None
- ),
- extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
- "extend_logprob_start_len_per_req", None
- ),
- bid=bids[next_mb_id],
+ next_pp_outputs=next_pp_outputs,
  can_run_cuda_graph=result.can_run_cuda_graph,
  )
  self.process_batch_result(mbs[next_mb_id], output_result)
@@ -996,8 +1084,6 @@

  # (not last rank)
  if not self.pp_group.is_last_rank:
- if self.cur_batch:
- bids[mb_id] = result.bid
  # carry the outputs to the next stage
  # send the outputs from the last round to let the next stage worker run post processing
  if pp_outputs:
@@ -1019,8 +1105,10 @@

  # send out proxy tensors to the next stage
  if self.cur_batch:
+ # FIXME(lsyin): remove this assert
+ assert result.pp_hidden_states_proxy_tensors.tensors is not None
  self.pp_group.send_tensor_dict(
- result.pp_hidden_states_proxy_tensors,
+ result.pp_hidden_states_proxy_tensors.tensors,
  all_gather_group=self.attn_tp_group,
  )

@@ -1131,10 +1219,13 @@ class Scheduler(
  src=self.tp_group.ranks[0],
  )

- for req in recv_reqs:
- if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
- trace_set_proc_propagate_context(req.rid, req.trace_context)
- trace_slice_start("", req.rid, anonymous=True)
+ if self.enable_trace:
+ for req in recv_reqs:
+ if isinstance(
+ req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+ ):
+ trace_set_proc_propagate_context(req.rid, req.trace_context)
+ trace_slice_start("", req.rid, anonymous=True)

  return recv_reqs

@@ -1149,20 +1240,6 @@ class Scheduler(
  self.return_health_check_ct += 1
  continue

- # If it is a work request, accept or reject the request based on the request queue size.
- if is_work_request(recv_req):
- if len(self.waiting_queue) + 1 > self.max_queued_requests:
- abort_req = AbortReq(
- recv_req.rid,
- finished_reason={
- "type": "abort",
- "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
- "message": "The request queue is full.",
- },
- )
- self.send_to_tokenizer.send_pyobj(abort_req)
- continue
-

  # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
  if isinstance(recv_req, MultiTokenizerWrapper):
@@ -1195,8 +1272,6 @@ class Scheduler(
  self,
  recv_req: TokenizedGenerateReqInput,
  ):
- self.maybe_update_dp_balance_data(recv_req)
-
  # Create a new request
  if (
  recv_req.session_params is None
@@ -1230,8 +1305,13 @@ class Scheduler(
  bootstrap_host=recv_req.bootstrap_host,
  bootstrap_port=recv_req.bootstrap_port,
  bootstrap_room=recv_req.bootstrap_room,
+ disagg_mode=self.disaggregation_mode,
  data_parallel_rank=recv_req.data_parallel_rank,
  vocab_size=self.model_config.vocab_size,
+ priority=recv_req.priority,
+ metrics_collector=(
+ self.metrics_collector if self.enable_metrics else None
+ ),
  )
  req.tokenizer = self.tokenizer

@@ -1352,7 +1432,6 @@ class Scheduler(
  req.set_finish_with_abort(error_msg)

  if add_to_grammar_queue:
- req.queue_time_start = time.perf_counter()
  self.grammar_queue.append(req)
  else:
  self._add_request_to_queue(req)
@@ -1368,20 +1447,6 @@ class Scheduler(
  for tokenized_req in recv_req:
  self.handle_generate_request(tokenized_req)

- def _add_request_to_queue(self, req: Req):
- req.queue_time_start = time.perf_counter()
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
- self._prefetch_kvcache(req)
- self.disagg_prefill_bootstrap_queue.add(
- req, self.model_config.num_key_value_heads
- )
- elif self.disaggregation_mode == DisaggregationMode.DECODE:
- self.disagg_decode_prealloc_queue.add(req)
- else:
- self._prefetch_kvcache(req)
- self.waiting_queue.append(req)
- trace_slice_end("process req", req.rid, auto_next_anon=True)
-
  def _prefetch_kvcache(self, req: Req):
  if self.enable_hicache_storage:
  req.init_next_round_input(self.tree_cache)
@@ -1395,16 +1460,87 @@ class Scheduler(
  req.rid, req.last_host_node, new_input_tokens, last_hash
  )

- def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
- self.disagg_prefill_bootstrap_queue.extend(
- reqs, self.model_config.num_key_value_heads
+ def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
+ if self.disaggregation_mode == DisaggregationMode.NULL:
+ self._set_or_validate_priority(req)
+ if self._abort_on_queued_limit(req):
+ return
+ self._prefetch_kvcache(req)
+ self.waiting_queue.append(req)
+ req.time_stats.wait_queue_entry_time = time.perf_counter()
+ trace_slice_end("process req", req.rid, auto_next_anon=True)
+ elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+ self._prefetch_kvcache(req)
+ self.disagg_prefill_bootstrap_queue.add(
+ req, self.model_config.num_key_value_heads
  )
+ req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
- # If this is a decode server, we put the request to the decode pending prealloc queue
- self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
+ self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
+ if not is_retracted:
+ req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
  else:
- self.waiting_queue.extend(reqs)
+ raise ValueError(f"Invalid {self.disaggregation_mode=}")
+
+ def _set_or_validate_priority(self, req: Req):
+ """Set the default priority value, or abort the request based on the priority scheduling mode."""
+ if self.enable_priority_scheduling and req.priority is None:
+ if self.schedule_low_priority_values_first:
+ req.priority = sys.maxsize
+ else:
+ req.priority = -sys.maxsize - 1
+ elif not self.enable_priority_scheduling and req.priority is not None:
+ abort_req = AbortReq(
+ finished_reason={
+ "type": "abort",
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+ "message": "Using priority is disabled for this server. Please send a new request without a priority.",
+ },
+ rid=req.rid,
+ )
+ self.send_to_tokenizer.send_pyobj(abort_req)
+
+ def _abort_on_queued_limit(self, recv_req: Req) -> bool:
+ """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
+ if (
+ self.max_queued_requests is None
+ or len(self.waiting_queue) + 1 <= self.max_queued_requests
+ ):
+ return False
+
+ # Reject the incoming request by default.
+ req_to_abort = recv_req
+ message = "The request queue is full."
+ if self.enable_priority_scheduling:
+ # With priority scheduling, consider aboritng an existing request based on the priority.
+ # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
+ # max(...) + (direction * priority, queue_time_start) picks the least-preferred request.
+ # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
+ direction = 1 if self.schedule_low_priority_values_first else -1
+ key_fn = lambda item: (
+ direction * item[1].priority,
+ item[1].time_stats.wait_queue_entry_time,
+ )
+ idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
+ abort_existing_req = (
+ direction * recv_req.priority < direction * candidate_req.priority
+ )
+ if abort_existing_req:
+ self.waiting_queue.pop(idx)
+ req_to_abort = candidate_req
+ message = "The request is aborted by a higher priority request."
+
+ self.send_to_tokenizer.send_pyobj(
+ AbortReq(
+ finished_reason={
+ "type": "abort",
+ "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+ "message": message,
+ },
+ rid=req_to_abort.rid,
+ )
+ )
+ return req_to_abort.rid == recv_req.rid

  def handle_embedding_request(
  self,
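_abort_on_queued_limit() above decides, when the waiting queue exceeds max_queued_requests, whether to reject the incoming request or evict the least-preferred queued one. A minimal, self-contained sketch (illustrative, not part of the package) of that selection rule, using a stand-in request type and reproducing only the priority/entry-time comparison:

    from dataclasses import dataclass

    @dataclass
    class WaitingReq:
        rid: str
        priority: int
        wait_queue_entry_time: float

    def pick_request_to_abort(waiting, incoming, low_values_first=True):
        # direction = 1: smaller priority number wins; -1: larger number wins.
        # max() over (direction * priority, entry_time) picks the least-preferred
        # request, breaking ties toward the newest arrival. The incoming request
        # displaces it only when strictly better; otherwise it is rejected itself.
        direction = 1 if low_values_first else -1
        idx, candidate = max(
            enumerate(waiting),
            key=lambda item: (direction * item[1].priority, item[1].wait_queue_entry_time),
        )
        if direction * incoming.priority < direction * candidate.priority:
            return candidate.rid  # evict an existing, lower-priority request
        return incoming.rid       # reject the incoming request

    queue = [WaitingReq("a", 0, 1.0), WaitingReq("b", 5, 2.0), WaitingReq("c", 5, 3.0)]
    print(pick_request_to_abort(queue, WaitingReq("new", 1, 4.0)))  # "c"
    print(pick_request_to_abort(queue, WaitingReq("new", 9, 4.0)))  # "new"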
@@ -1416,6 +1552,7 @@ class Scheduler(
  recv_req.input_ids,
  recv_req.sampling_params,
  token_type_ids=recv_req.token_type_ids,
+ priority=recv_req.priority,
  )
  req.tokenizer = self.tokenizer

@@ -1660,7 +1797,6 @@ class Scheduler(

  # Handle DP attention
  if need_dp_attn_preparation:
- self.maybe_handle_dp_balance_data()
  ret = self.prepare_mlp_sync_batch(ret)

  return ret
@@ -1676,6 +1812,10 @@ class Scheduler(
  if self.grammar_queue:
  self.move_ready_grammar_requests()

+ if self.try_preemption:
+ # Reset batch_is_full to try preemption with a prefill adder.
+ self.running_batch.batch_is_full = False
+
  # Handle the cases where prefill is not allowed
  if (
  self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1688,7 +1828,11 @@ class Scheduler(
  # as the space for the chunked request has just been released.
  # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
  # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
- if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
+ if (
+ self.get_num_allocatable_reqs(running_bs) <= 0
+ and not self.chunked_req
+ and not self.try_preemption
+ ):
  self.running_batch.batch_is_full = True
  return None

@@ -1708,6 +1852,7 @@ class Scheduler(
  self.max_prefill_tokens,
  self.chunked_prefill_size,
  running_bs if self.is_mixed_chunk else 0,
+ self.priority_scheduling_preemption_threshold,
  )

  if self.chunked_req is not None:
@@ -1728,15 +1873,19 @@ class Scheduler(
  self.running_batch.batch_is_full = True
  break

+ running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
  if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
  self.running_batch.batch_is_full = True
- break
-
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
  # In prefill mode, prealloc queue and transfer queue can also take memory,
  # so we need to check if the available size for the actual available size.
  if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
  self.running_batch.batch_is_full = True
+
+ if self.running_batch.batch_is_full:
+ if not self.try_preemption:
+ break
+ if not adder.preempt_to_schedule(req, self.server_args):
  break

  if self.enable_hicache_storage:
@@ -1746,7 +1895,11 @@ class Scheduler(
  continue

  req.init_next_round_input(self.tree_cache)
- res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
+ res = adder.add_one_req(
+ req,
+ has_chunked_req=(self.chunked_req is not None),
+ truncation_align_size=self.truncation_align_size,
+ )

  if res != AddReqResult.CONTINUE:
  if res == AddReqResult.NO_TOKEN:
@@ -1767,11 +1920,14 @@ class Scheduler(
  if self.enable_metrics:
  # only record queue time when enable_metrics is True to avoid overhead
  for req in can_run_list:
- req.queue_time_end = time.perf_counter()
+ req.add_latency(RequestStage.PREFILL_WAITING)

  self.waiting_queue = [
  x for x in self.waiting_queue if x not in set(can_run_list)
  ]
+ if adder.preempt_list:
+ for req in adder.preempt_list:
+ self._add_request_to_queue(req)

  if adder.new_chunked_req is not None:
  assert self.chunked_req is None
@@ -1782,7 +1938,16 @@ class Scheduler(

  # Print stats
  if self.current_scheduler_metrics_enabled():
- self.log_prefill_stats(adder, can_run_list, running_bs)
+ self.log_prefill_stats(adder, can_run_list, running_bs, 0)
+
+ for req in can_run_list:
+ if req.time_stats.forward_entry_time == 0:
+ # Avoid update chunked request many times
+ req.time_stats.forward_entry_time = time.perf_counter()
+ if self.enable_metrics:
+ self.metrics_collector.observe_queue_time(
+ req.time_stats.get_queueing_time(),
+ )

  # Create a new batch
  new_batch = ScheduleBatch.init_new(
@@ -1837,19 +2002,25 @@ class Scheduler(
  TEST_RETRACT and batch.batch_size() > 10
  ):
  old_ratio = self.new_token_ratio
-
- retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
- num_retracted_reqs = len(retracted_reqs)
+ retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
+ self.server_args
+ )
+ self.num_retracted_reqs = len(retracted_reqs)
  self.new_token_ratio = new_token_ratio
+ for req in reqs_to_abort:
+ self.send_to_tokenizer.send_pyobj(
+ AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
+ )

  logger.info(
  "KV cache pool is full. Retract requests. "
- f"#retracted_reqs: {num_retracted_reqs}, "
- f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+ f"#retracted_reqs: {len(retracted_reqs)}, "
+ f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
+ f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
  )

- self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
- self.total_retracted_reqs += num_retracted_reqs
+ for req in retracted_reqs:
+ self._add_request_to_queue(req, is_retracted=True)
  else:
  self.new_token_ratio = max(
  self.new_token_ratio - self.new_token_ratio_decay,
@@ -1877,33 +2048,25 @@ class Scheduler(

  # Run forward
  if self.is_generation:
+
+ batch_or_worker_batch = batch
+
  if self.spec_algorithm.is_none():
- model_worker_batch = batch.get_model_worker_batch()
+ # FIXME(lsyin): remove this if and finally unify the abstraction
+ batch_or_worker_batch = batch.get_model_worker_batch()

- if self.pp_group.is_last_rank:
- logits_output, next_token_ids, can_run_cuda_graph = (
- self.tp_worker.forward_batch_generation(model_worker_batch)
- )
- else:
- pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
- self.tp_worker.forward_batch_generation(model_worker_batch)
- )
- bid = model_worker_batch.bid
- else:
- (
- logits_output,
- next_token_ids,
- bid,
- num_accepted_tokens,
- can_run_cuda_graph,
- ) = self.draft_worker.forward_batch_speculative_generation(batch)
- bs = batch.batch_size()
- self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
- self.spec_num_total_forward_ct += bs
- self.num_generated_tokens += num_accepted_tokens
-
- if self.pp_group.is_last_rank:
- batch.output_ids = next_token_ids
+ forward_batch_output = self.model_worker.forward_batch_generation(
+ batch_or_worker_batch
+ )
+
+ if not self.spec_algorithm.is_none():
+ # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+ self.udpate_spec_metrics(
+ batch.batch_size(), forward_batch_output.num_accepted_tokens
+ )
+
+ # update batch's output ids
+ batch.output_ids = forward_batch_output.next_token_ids

  # These 2 values are needed for processing the output, but the values can be
  # modified by overlap schedule. So we have to copy them here so that
@@ -1912,6 +2075,7 @@ class Scheduler(
  extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
  else:
  extend_input_len_per_req = None
+
  if batch.return_logprob:
  extend_logprob_start_len_per_req = [
  req.extend_logprob_start_len for req in batch.reqs
@@ -1919,25 +2083,15 @@ class Scheduler(
  else:
  extend_logprob_start_len_per_req = None

- ret = GenerationBatchResult(
- logits_output=logits_output if self.pp_group.is_last_rank else None,
- pp_hidden_states_proxy_tensors=(
- pp_hidden_states_proxy_tensors
- if not self.pp_group.is_last_rank
- else None
- ),
- next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
+ return GenerationBatchResult.from_forward_batch_output(
+ forward_batch_output=forward_batch_output,
  extend_input_len_per_req=extend_input_len_per_req,
  extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
- bid=bid,
- can_run_cuda_graph=can_run_cuda_graph,
  )
  else: # embedding or reward model
  model_worker_batch = batch.get_model_worker_batch()
  embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
- ret = EmbeddingBatchResult(
- embeddings=embeddings, bid=model_worker_batch.bid
- )
+ ret = EmbeddingBatchResult(embeddings=embeddings)
  return ret

  def process_batch_result(
@@ -1948,23 +2102,14 @@
  ):
  if batch.forward_mode.is_decode():
  self.process_batch_result_decode(batch, result, launch_done)
- for req in batch.reqs:
- trace_slice(
- "decode loop",
- req.rid,
- auto_next_anon=not req.finished(),
- thread_finish_flag=req.finished(),
- )
+ if self.enable_trace:
+ trace_slice_batch("decode loop", batch.reqs)

  elif batch.forward_mode.is_extend():
  self.process_batch_result_prefill(batch, result, launch_done)
- for req in batch.reqs:
- trace_slice(
- "prefill",
- req.rid,
- auto_next_anon=not req.finished(),
- thread_finish_flag=req.finished(),
- )
+ if self.enable_trace:
+ trace_slice_batch("prefill", batch.reqs)
+
  elif batch.forward_mode.is_idle():
  if self.enable_overlap:
  self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2123,12 +2268,13 @@ class Scheduler(
  if req.finished(): # It is aborted by AbortReq
  num_ready_reqs += 1
  continue
+
  req.grammar = req.grammar.result(timeout=0.03)
  self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
  if req.grammar is INVALID_GRAMMAR_OBJ:
- req.set_finish_with_abort(
- f"Invalid grammar request: {req.grammar_key=}"
- )
+ error_msg = f"Invalid grammar request: {req.grammar_key=}"
+ req.set_finish_with_abort(error_msg)
+
  num_ready_reqs += 1
  except futures._base.TimeoutError:
  req.grammar_wait_ct += 1
@@ -2160,9 +2306,8 @@ class Scheduler(
  req.grammar = req.grammar.result()
  self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
  if req.grammar is INVALID_GRAMMAR_OBJ:
- req.set_finish_with_abort(
- f"Invalid grammar request: {req.grammar_key=}"
- )
+ error_msg = f"Invalid grammar request: {req.grammar_key=}"
+ req.set_finish_with_abort(error_msg)
  else:
  num_ready_reqs_max = num_ready_reqs
  num_timeout_reqs_max = num_timeout_reqs
@@ -2170,12 +2315,14 @@ class Scheduler(
  for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
  req = self.grammar_queue[i]
  req.grammar.cancel()
+ self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
  error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
  req.set_finish_with_abort(error_msg)
- self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
+
  num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

- self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
+ for req in self.grammar_queue[:num_ready_reqs]:
+ self._add_request_to_queue(req)
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]

  def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
@@ -2267,9 +2414,8 @@ class Scheduler(
  self.req_to_token_pool.clear()
  self.token_to_kv_pool_allocator.clear()

- if not self.spec_algorithm.is_none():
- self.draft_worker.model_runner.req_to_token_pool.clear()
- self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
+ if self.draft_worker:
+ self.draft_worker.clear_cache_pool()

  self.num_generated_tokens = 0
  self.forward_ct_decode = 0
@@ -2433,7 +2579,7 @@ class Scheduler(
  if self.enable_hicache_storage:
  # to release prefetch events associated with the request
  self.tree_cache.release_aborted_request(req.rid)
- self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+ self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
  # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
  if self.disaggregation_mode == DisaggregationMode.DECODE:
  self.tree_cache.cache_finished_req(req)
@@ -2454,31 +2600,31 @@ class Scheduler(
  # Delete requests not in the waiting queue when PD disaggregation is enabled
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
  # Abort requests that have not yet been bootstrapped
- for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
- logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+ for req in self.disagg_prefill_bootstrap_queue.queue:
  if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+ logger.debug(f"Abort bootstrap queue request. {req.rid=}")
  if hasattr(req.disagg_kv_sender, "abort"):
  req.disagg_kv_sender.abort()

  # Abort in-flight requests
- for i, req in enumerate(self.disagg_prefill_inflight_queue):
- logger.debug(f"Abort inflight queue request. {req.rid=}")
+ for req in self.disagg_prefill_inflight_queue:
  if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+ logger.debug(f"Abort inflight queue request. {req.rid=}")
  if hasattr(req.disagg_kv_sender, "abort"):
  req.disagg_kv_sender.abort()

  elif self.disaggregation_mode == DisaggregationMode.DECODE:
  # Abort requests that have not yet finished preallocation
- for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
- logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+ for decode_req in self.disagg_decode_prealloc_queue.queue:
  if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+ logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
  if hasattr(decode_req.kv_receiver, "abort"):
  decode_req.kv_receiver.abort()

  # Abort requests waiting for kvcache to release tree cache
- for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
- logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+ for decode_req in self.disagg_decode_transfer_queue.queue:
  if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+ logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
  if hasattr(decode_req.kv_receiver, "abort"):
  decode_req.kv_receiver.abort()

@@ -2545,11 +2691,12 @@ class Scheduler(
  return SlowDownReqOutput()

  def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
- if recv_req == ExpertDistributionReq.START_RECORD:
+ action = recv_req.action
+ if action == ExpertDistributionReqType.START_RECORD:
  get_global_expert_distribution_recorder().start_record()
- elif recv_req == ExpertDistributionReq.STOP_RECORD:
+ elif action == ExpertDistributionReqType.STOP_RECORD:
  get_global_expert_distribution_recorder().stop_record()
- elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+ elif action == ExpertDistributionReqType.DUMP_RECORD:
  get_global_expert_distribution_recorder().dump_record()
  else:
  raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
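The hunk above changes expert_distribution_handle() to dispatch on recv_req.action (an ExpertDistributionReqType) instead of comparing the request object itself to an enum. A small sketch (illustrative, not part of the package) of the new request shape, with stand-in definitions for the io_struct types:

    from dataclasses import dataclass
    from enum import Enum, auto

    class ExpertDistributionReqType(Enum):
        START_RECORD = auto()
        STOP_RECORD = auto()
        DUMP_RECORD = auto()

    @dataclass
    class ExpertDistributionReq:
        # The request now carries the action as a field.
        action: ExpertDistributionReqType

    def expert_distribution_handle(recv_req: ExpertDistributionReq) -> str:
        action = recv_req.action
        if action == ExpertDistributionReqType.START_RECORD:
            return "start_record"
        elif action == ExpertDistributionReqType.STOP_RECORD:
            return "stop_record"
        elif action == ExpertDistributionReqType.DUMP_RECORD:
            return "dump_record"
        raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")

    print(expert_distribution_handle(ExpertDistributionReq(ExpertDistributionReqType.DUMP_RECORD)))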
@@ -2632,7 +2779,8 @@ class IdleSleeper:


  def is_health_check_generate_req(recv_req):
- return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+ rid = getattr(recv_req, "rid", None)
+ return rid is not None and rid.startswith("HEALTH_CHECK")


  def is_work_request(recv_req):
@@ -2656,19 +2804,12 @@
  pp_rank: int,
  dp_rank: Optional[int],
  pipe_writer,
- balance_meta: Optional[DPBalanceMeta] = None,
  ):
- if server_args.enable_trace:
- process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
- if server_args.disaggregation_mode == "null":
- thread_label = "Scheduler"
- trace_set_thread_info(thread_label, tp_rank, dp_rank)
-
- if (numa_node := server_args.numa_node) is not None:
- numa_bind_to_node(numa_node[gpu_id])
-
- # Generate the prefix
+ # Generate the logger prefix
  prefix = ""
+ if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
+ # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
+ dp_rank = int(os.environ["SGLANG_DP_RANK"])
  if dp_rank is not None:
  prefix += f" DP{dp_rank}"
  if server_args.tp_size > 1:
@@ -2684,10 +2825,6 @@
  kill_itself_when_parent_died()
  parent_process = psutil.Process().parent()

- # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
- if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
- dp_rank = int(os.environ["SGLANG_DP_RANK"])
-
  # Configure the logger
  configure_logger(server_args, prefix=prefix)
  suppress_other_loggers()
@@ -2695,6 +2832,15 @@
  # Set cpu affinity to this gpu process
  if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
  set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+ if (numa_node := server_args.numa_node) is not None:
+ numa_bind_to_node(numa_node[gpu_id])
+
+ # Set up tracing
+ if server_args.enable_trace:
+ process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+ if server_args.disaggregation_mode == "null":
+ thread_label = "Scheduler"
+ trace_set_thread_info(thread_label, tp_rank, dp_rank)

  # Create a scheduler and run the event loop
  try:
@@ -2706,7 +2852,6 @@
  moe_ep_rank,
  pp_rank,
  dp_rank,
- dp_balance_meta=balance_meta,
  )
  pipe_writer.send(
  {