sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -24,16 +24,16 @@ from collections import deque
  from concurrent import futures
  from dataclasses import dataclass
  from http import HTTPStatus
- from types import SimpleNamespace
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import Deque, Dict, List, Optional, Tuple, Union

  import psutil
  import setproctitle
  import torch
  import zmq
+ from torch.cuda import Stream as CudaStream
+ from torch.cuda import StreamContext as CudaStreamContext
  from torch.distributed import barrier

- from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.constrained.base_grammar_backend import (
  INVALID_GRAMMAR_OBJ,
@@ -59,12 +59,14 @@ from sglang.srt.disaggregation.utils import (
  prepare_abort,
  )
  from sglang.srt.distributed import get_pp_group, get_world_group
+ from sglang.srt.environ import envs
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
- from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.moe import initialize_moe_config
  from sglang.srt.managers.io_struct import (
  AbortReq,
+ BaseBatchReq,
+ BaseReq,
  BatchTokenizedEmbeddingReqInput,
  BatchTokenizedGenerateReqInput,
  ClearHiCacheReqInput,
@@ -88,8 +90,6 @@ from sglang.srt.managers.io_struct import (
  InitWeightsUpdateGroupReqInput,
  LoadLoRAAdapterReqInput,
  LoadLoRAAdapterReqOutput,
- MultiTokenizerRegisterReq,
- MultiTokenizerWrapper,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
@@ -109,16 +109,18 @@ from sglang.srt.managers.io_struct import (
  UnloadLoRAAdapterReqOutput,
  UpdateWeightFromDiskReqInput,
  UpdateWeightsFromDistributedReqInput,
+ UpdateWeightsFromIPCReqInput,
  UpdateWeightsFromTensorReqInput,
  )
  from sglang.srt.managers.mm_utils import init_embedding_cache
+ from sglang.srt.managers.overlap_utils import FutureMap
  from sglang.srt.managers.schedule_batch import (
  FINISH_ABORT,
+ ModelWorkerBatch,
  MultimodalInputs,
  Req,
  RequestStage,
  ScheduleBatch,
- global_server_args_dict,
  )
  from sglang.srt.managers.schedule_policy import (
  AddReqResult,
@@ -133,28 +135,25 @@ from sglang.srt.managers.scheduler_metrics_mixin import (
  from sglang.srt.managers.scheduler_output_processor_mixin import (
  SchedulerOutputProcessorMixin,
  )
+ from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
  from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
  from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
+ from sglang.srt.managers.scheduler_runtime_checker_mixin import (
+ SchedulerRuntimeCheckerMixin,
+ )
  from sglang.srt.managers.scheduler_update_weights_mixin import (
  SchedulerUpdateWeightsMixin,
  )
  from sglang.srt.managers.session_controller import Session
- from sglang.srt.managers.tp_worker import TpModelWorker
- from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
- from sglang.srt.managers.utils import validate_input_length
+ from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
+ from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
- from sglang.srt.model_executor.forward_batch_info import (
- ForwardBatchOutput,
- ForwardMode,
- PPProxyTensors,
- )
  from sglang.srt.parser.reasoning_parser import ReasoningParser
- from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
- from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.tracing.trace import (
  process_tracing_init,
  trace_set_proc_propagate_context,
@@ -190,64 +189,17 @@ from sglang.srt.utils.hf_transformers_utils import (
  get_tokenizer,
  get_tokenizer_from_processor,
  )
+ from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.utils import TypeBasedDispatcher, get_exception_traceback

  logger = logging.getLogger(__name__)

  # Test retract decode for debugging purposes
- TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
+ TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
+ TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
  GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


- @dataclass
- class GenerationBatchResult:
- logits_output: Optional[LogitsProcessorOutput]
- pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
- next_token_ids: Optional[List[int]]
- can_run_cuda_graph: bool
-
- # For output processing
- extend_input_len_per_req: List[int]
- extend_logprob_start_len_per_req: List[int]
-
- @classmethod
- def from_forward_batch_output(
- cls,
- forward_batch_output: ForwardBatchOutput,
- extend_input_len_per_req: List[int],
- extend_logprob_start_len_per_req: List[int],
- ):
- # TODO(lsyin): remove this workaround logic and try to unify output classes
-
- return cls(
- logits_output=forward_batch_output.logits_output,
- pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
- next_token_ids=forward_batch_output.next_token_ids,
- extend_input_len_per_req=extend_input_len_per_req,
- extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
- can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
- )
-
- @classmethod
- def from_pp_proxy(
- cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
- ):
- # TODO(lsyin): also simplify this logic
- # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
- # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
- proxy_dict = next_pp_outputs.tensors
- return cls(
- logits_output=logits_output,
- pp_hidden_states_proxy_tensors=None,
- next_token_ids=next_pp_outputs["next_token_ids"],
- extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
- extend_logprob_start_len_per_req=proxy_dict.get(
- "extend_logprob_start_len_per_req", None
- ),
- can_run_cuda_graph=can_run_cuda_graph,
- )
-
-
  @dataclass
  class EmbeddingBatchResult:
  embeddings: torch.Tensor
@@ -260,6 +212,8 @@ class Scheduler(
  SchedulerMetricsMixin,
  SchedulerDisaggregationDecodeMixin,
  SchedulerDisaggregationPrefillMixin,
+ SchedulerRuntimeCheckerMixin,
+ SchedulerPPMixin,
  ):
  """A scheduler that manages a tensor parallel GPU worker."""

@@ -285,6 +239,9 @@ class Scheduler(
  self.dp_size = server_args.dp_size
  self.schedule_policy = server_args.schedule_policy
  self.enable_priority_scheduling = server_args.enable_priority_scheduling
+ self.abort_on_priority_when_disabled = (
+ server_args.abort_on_priority_when_disabled
+ )
  self.schedule_low_priority_values_first = (
  server_args.schedule_low_priority_values_first
  )
@@ -325,47 +282,7 @@ class Scheduler(
  self.model_config = ModelConfig.from_server_args(server_args)

  # Init inter-process communication
- context = zmq.Context(2)
- self.idle_sleeper = None
- if self.pp_rank == 0 and self.attn_tp_rank == 0:
- self.recv_from_tokenizer = get_zmq_socket(
- context, zmq.PULL, port_args.scheduler_input_ipc_name, False
- )
- self.recv_from_rpc = get_zmq_socket(
- context, zmq.DEALER, port_args.rpc_ipc_name, False
- )
-
- self.send_to_tokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.tokenizer_ipc_name, False
- )
- if server_args.skip_tokenizer_init:
- # Directly send to the TokenizerManager
- self.send_to_detokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.tokenizer_ipc_name, False
- )
- else:
- # Send to the DetokenizerManager
- self.send_to_detokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.detokenizer_ipc_name, False
- )
-
- if self.server_args.sleep_on_idle:
- self.idle_sleeper = IdleSleeper(
- [
- self.recv_from_tokenizer,
- self.recv_from_rpc,
- ]
- )
- else:
- self.recv_from_tokenizer = None
- self.recv_from_rpc = None
- self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
- self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
-
- if self.current_scheduler_metrics_enabled():
- self.send_metrics_from_scheduler = get_zmq_socket(
- context, zmq.PUSH, port_args.metrics_ipc_name, False
- )
+ self.init_sockets(server_args, port_args)

  # Init tokenizer
  self.init_tokenizer()
@@ -388,12 +305,10 @@ class Scheduler(
  logger.info("Overlap scheduler is disabled for embedding models.")

  # Launch a tensor parallel worker
- if self.enable_overlap:
- TpWorkerClass = TpModelWorkerClient
- else:
- TpWorkerClass = TpModelWorker

- self.tp_worker = TpWorkerClass(
+ from sglang.srt.managers.tp_worker import TpModelWorker
+
+ self.tp_worker = TpModelWorker(
  server_args=server_args,
  gpu_id=gpu_id,
  tp_rank=tp_rank,
@@ -404,44 +319,10 @@ class Scheduler(
  )

  # Launch a draft worker for speculative decoding
- if self.spec_algorithm.is_eagle():
- from sglang.srt.speculative.eagle_worker import EAGLEWorker
-
- self.draft_worker = EAGLEWorker(
- gpu_id=gpu_id,
- tp_rank=tp_rank,
- moe_ep_rank=moe_ep_rank,
- server_args=server_args,
- nccl_port=port_args.nccl_port,
- target_worker=self.tp_worker,
- dp_rank=dp_rank,
- )
- elif self.spec_algorithm.is_standalone():
- from sglang.srt.speculative.standalone_worker import StandaloneWorker
-
- self.draft_worker = StandaloneWorker(
- gpu_id=gpu_id,
- tp_rank=tp_rank,
- moe_ep_rank=moe_ep_rank,
- server_args=server_args,
- nccl_port=port_args.nccl_port,
- target_worker=self.tp_worker,
- dp_rank=dp_rank,
- )
- elif self.spec_algorithm.is_ngram():
- from sglang.srt.speculative.ngram_worker import NGRAMWorker

- self.draft_worker = NGRAMWorker(
- gpu_id=gpu_id,
- tp_rank=tp_rank,
- moe_ep_rank=moe_ep_rank,
- server_args=server_args,
- nccl_port=port_args.nccl_port,
- target_worker=self.tp_worker,
- dp_rank=dp_rank,
- )
- else:
- self.draft_worker = None
+ self.launch_draft_worker(
+ gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+ )

  # Dispatch the model worker
  if self.spec_algorithm.is_none():
@@ -459,13 +340,12 @@ class Scheduler(
  self.max_req_input_len,
  self.random_seed,
  self.device,
- worker_global_server_args_dict,
  _,
  _,
  _,
  ) = self.tp_worker.get_worker_info()
- if global_server_args_dict["max_micro_batch_size"] is None:
- global_server_args_dict["max_micro_batch_size"] = max(
+ if get_global_server_args().pp_max_micro_batch_size is None:
+ get_global_server_args().pp_max_micro_batch_size = max(
  self.max_running_requests // server_args.pp_size, 1
  )

@@ -477,11 +357,12 @@ class Scheduler(
  self.world_group = get_world_group()

  self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
- global_server_args_dict.update(worker_global_server_args_dict)
  set_random_seed(self.random_seed)

  # Hybrid memory pool
  self.is_hybrid = self.tp_worker.is_hybrid
+ self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None
+
  if self.is_hybrid:
  self.sliding_window_size = self.tp_worker.sliding_window_size
  self.full_tokens_per_layer, self.swa_tokens_per_layer = (
@@ -525,9 +406,11 @@ class Scheduler(
  self.kv_transfer_speed_gb_s: float = 0.0
  self.kv_transfer_latency_ms: float = 0.0
  self.sessions: Dict[str, Session] = {}
- self.current_stream = torch.get_device_module(self.device).current_stream()
+ self.default_stream: CudaStream = torch.get_device_module(
+ self.device
+ ).current_stream()
  if self.device == "cpu":
- self.current_stream.synchronize = lambda: None # No-op for CPU
+ self.default_stream.synchronize = lambda: None # No-op for CPU
  self.forward_sleep_time = None

  # Init chunked prefill
@@ -566,18 +449,17 @@ class Scheduler(
  server_args.schedule_conservativeness >= 0
  ), "Invalid schedule_conservativeness"
  self.init_new_token_ratio = min(
- global_config.default_init_new_token_ratio
+ envs.SGLANG_INIT_NEW_TOKEN_RATIO.get()
  * server_args.schedule_conservativeness,
  1.0,
  )
  self.min_new_token_ratio = min(
- self.init_new_token_ratio
- * global_config.default_min_new_token_ratio_factor,
+ self.init_new_token_ratio * envs.SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR.get(),
  1.0,
  )
  self.new_token_ratio_decay = (
  self.init_new_token_ratio - self.min_new_token_ratio
- ) / global_config.default_new_token_ratio_decay_steps
+ ) / envs.SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS.get()
  self.new_token_ratio = self.init_new_token_ratio

  # Init watchdog thread
@@ -612,12 +494,15 @@ class Scheduler(
  )
  self.init_disaggregation()

- if get_bool_env_var("SGLANG_GC_LOG"):
+ if envs.SGLANG_LOG_GC.get():
  configure_gc_logger()

  # Init prefill kv split size when deterministic inference is enabled with various attention backends
  self.init_deterministic_inference_config()

+ # Init overlap
+ self.init_overlap()
+
  # Init request dispatcher
  self._request_dispatcher = TypeBasedDispatcher(
  [
@@ -646,6 +531,7 @@ class Scheduler(
  self.update_weights_from_distributed,
  ),
  (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+ (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
  (GetWeightsByNameReqInput, self.get_weights_by_name),
  (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
  (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
@@ -658,11 +544,130 @@ class Scheduler(
  (ExpertDistributionReq, self.expert_distribution_handle),
  (LoadLoRAAdapterReqInput, self.load_lora_adapter),
  (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
- (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
  (GetLoadReqInput, self.get_load),
  ]
  )

+ def launch_draft_worker(
+ self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+ ):
+ if server_args.speculative_draft_load_format is not None:
+ server_args.load_format = server_args.speculative_draft_load_format
+ logger.info(
+ f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
+ )
+
+ if self.spec_algorithm.is_eagle():
+ from sglang.srt.speculative.eagle_worker import EAGLEWorker
+ from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
+
+ WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
+
+ self.draft_worker = WorkerClass(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
+ elif self.spec_algorithm.is_standalone():
+ from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+ self.draft_worker = StandaloneWorker(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
+ elif self.spec_algorithm.is_ngram():
+ from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+ self.draft_worker = NGRAMWorker(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
+ else:
+ self.draft_worker = None
+
+ def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
+ context = zmq.Context(2)
+ self.idle_sleeper = None
+
+ class SenderWrapper:
+ def __init__(self, socket: zmq.Socket):
+ self.socket = socket
+
+ def send_output(
+ self,
+ output: Union[BaseReq, BaseBatchReq],
+ recv_obj: Optional[Union[BaseReq, BaseBatchReq]] = None,
+ ):
+ if self.socket is None:
+ return
+
+ if (
+ isinstance(recv_obj, BaseReq)
+ and recv_obj.http_worker_ipc is not None
+ and output.http_worker_ipc is None
+ ):
+ # handle communicator reqs for multi-http worker case
+ output.http_worker_ipc = recv_obj.http_worker_ipc
+
+ self.socket.send_pyobj(output)
+
+ if self.pp_rank == 0 and self.attn_tp_rank == 0:
+ self.recv_from_tokenizer = get_zmq_socket(
+ context, zmq.PULL, port_args.scheduler_input_ipc_name, False
+ )
+ self.recv_from_rpc = get_zmq_socket(
+ context, zmq.DEALER, port_args.rpc_ipc_name, False
+ )
+
+ send_to_tokenizer = get_zmq_socket(
+ context, zmq.PUSH, port_args.tokenizer_ipc_name, False
+ )
+ if server_args.skip_tokenizer_init:
+ # Directly send to the TokenizerManager
+ send_to_detokenizer = get_zmq_socket(
+ context, zmq.PUSH, port_args.tokenizer_ipc_name, False
+ )
+ else:
+ # Send to the DetokenizerManager
+ send_to_detokenizer = get_zmq_socket(
+ context, zmq.PUSH, port_args.detokenizer_ipc_name, False
+ )
+
+ self.send_to_tokenizer = SenderWrapper(send_to_tokenizer)
+ self.send_to_detokenizer = SenderWrapper(send_to_detokenizer)
+
+ if self.server_args.sleep_on_idle:
+ self.idle_sleeper = IdleSleeper(
+ [
+ self.recv_from_tokenizer,
+ self.recv_from_rpc,
+ ]
+ )
+ else:
+ self.recv_from_tokenizer = None
+ self.recv_from_rpc = None
+ self.send_to_tokenizer = SenderWrapper(None)
+ self.send_to_detokenizer = SenderWrapper(None)
+
+ if self.current_scheduler_metrics_enabled():
+ self.send_metrics_from_scheduler = get_zmq_socket(
+ context, zmq.PUSH, port_args.metrics_ipc_name, False
+ )
+
  def init_deterministic_inference_config(self):
  """Initialize deterministic inference configuration for different attention backends."""
  if not self.server_args.enable_deterministic_inference:
@@ -768,15 +773,20 @@ class Scheduler(
  self.tree_cache.cache_controller.layer_done_counter
  )
  elif self.is_hybrid:
- assert (
- self.server_args.disaggregation_mode == "null"
- ), "Hybrid mode does not support disaggregation yet"
  self.tree_cache = SWARadixCache(
  req_to_token_pool=self.req_to_token_pool,
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
  sliding_window_size=self.sliding_window_size,
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
+ is_eagle=self.spec_algorithm.is_eagle(),
+ )
+ elif self.is_hybrid_gdn:
+ self.tree_cache = MambaRadixCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ page_size=self.page_size,
+ disable=server_args.disable_radix_cache,
  )
  elif server_args.enable_lmcache:
  from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
@@ -931,6 +941,34 @@ class Scheduler(
  # The prefill requests that are in the middle of kv sending
  self.disagg_prefill_inflight_queue: List[Req] = []

+ def init_overlap(self):
+ if not self.enable_overlap:
+ return
+
+ self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
+ self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
+ self.device
+ ).stream(self.forward_stream)
+ self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
+ self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
+ self.device
+ ).stream(self.copy_stream)
+
+ self.future_map = FutureMap(
+ self.max_running_requests, self.device, self.spec_algorithm
+ )
+ self.batch_record_buf = [None] * 2
+ self.batch_record_ct = 0
+
+ def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
+ # FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
+ # NOTE: More Reliable: record all tensors into the forward stream
+ # NOTE: - for all future tensors, we shall always read from future map
+ # - for all non-future tensors (produced only by schedule stream),
+ # we shall keep its reference not being release during all the forwarding pass
+ self.batch_record_ct = (self.batch_record_ct + 1) % 2
+ self.batch_record_buf[self.batch_record_ct] = model_worker_batch
+
  def init_moe_config(self):
  if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
  initialize_moe_config(self.server_args)
@@ -957,7 +995,7 @@ class Scheduler(
  @DynamicGradMode()
  def event_loop_overlap(self):
  """A scheduler loop that overlaps the CPU processing and GPU computation."""
- self.result_queue = deque()
+ self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

  while True:
  recv_reqs = self.recv_requests()
@@ -966,158 +1004,24 @@ class Scheduler(
  batch = self.get_next_batch_to_run()
  self.cur_batch = batch

+ batch_result = None
  if batch:
- batch.launch_done = threading.Event()
- result = self.run_batch(batch)
- self.result_queue.append((batch.copy(), result))
-
- if self.last_batch is None:
- # Create a dummy first batch to start the pipeline for overlap schedule.
- # It is now used for triggering the sampling_info_done event.
- tmp_batch = ScheduleBatch(
- reqs=None,
- forward_mode=ForwardMode.DUMMY_FIRST,
- next_batch_sampling_info=self.tp_worker.cur_sampling_info,
- )
- self.process_batch_result(tmp_batch, None, batch.launch_done)
+ batch_result = self.run_batch(batch)
+ self.result_queue.append((batch.copy(), batch_result))

  if self.last_batch:
  # Process the results of the last batch
  tmp_batch, tmp_result = self.result_queue.popleft()
- tmp_batch.next_batch_sampling_info = (
- self.tp_worker.cur_sampling_info if batch else None
- )
- # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
- self.process_batch_result(
- tmp_batch, tmp_result, batch.launch_done if batch else None
- )
+ self.process_batch_result(tmp_batch, tmp_result)
  elif batch is None:
  # When the server is idle, do self-check and re-init some states
  self.self_check_during_idle()

+ self.launch_batch_sample_if_needed(batch_result)
  self.last_batch = batch

- @DynamicGradMode()
- def event_loop_pp(self):
- """A non-overlap scheduler loop for pipeline parallelism."""
- mbs = [None] * self.pp_size
- last_mbs = [None] * self.pp_size
- self.running_mbs = [
- ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
- ]
- pp_outputs: Optional[PPProxyTensors] = None
- while True:
- server_is_idle = True
- for mb_id in range(self.pp_size):
- self.running_batch = self.running_mbs[mb_id]
- self.last_batch = last_mbs[mb_id]
-
- recv_reqs = self.recv_requests()
- self.process_input_requests(recv_reqs)
- mbs[mb_id] = self.get_next_batch_to_run()
- self.running_mbs[mb_id] = self.running_batch
-
- self.cur_batch = mbs[mb_id]
- if self.cur_batch:
- server_is_idle = False
- result = self.run_batch(self.cur_batch)
-
- # (last rank) send the outputs to the next step
- if self.pp_group.is_last_rank:
- if self.cur_batch:
- next_token_ids = result.next_token_ids
- if self.cur_batch.return_logprob:
- pp_outputs = PPProxyTensors(
- {
- "next_token_ids": next_token_ids,
- "extend_input_len_per_req": result.extend_input_len_per_req,
- "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
- }
- | (
- {
- f"logits_output.{k}": v
- for k, v in result.logits_output.__dict__.items()
- }
- if result.logits_output is not None
- else {}
- )
- )
- else:
- pp_outputs = PPProxyTensors(
- {
- "next_token_ids": next_token_ids,
- }
- )
- # send the output from the last round to let the next stage worker run post processing
- self.pp_group.send_tensor_dict(
- pp_outputs.tensors,
- all_gather_group=self.attn_tp_group,
- )
-
- # receive outputs and post-process (filter finished reqs) the coming microbatch
- next_mb_id = (mb_id + 1) % self.pp_size
- next_pp_outputs = None
- if mbs[next_mb_id] is not None:
- next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
- self.pp_group.recv_tensor_dict(
- all_gather_group=self.attn_tp_group
- )
- )
- mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
- logits_output_args = {
- k[len("logits_output.") :]: v
- for k, v in next_pp_outputs.tensors.items()
- if k.startswith("logits_output.")
- }
- if len(logits_output_args) > 0:
- logits_output = LogitsProcessorOutput(**logits_output_args)
- else:
- logits_output = None
-
- output_result = GenerationBatchResult.from_pp_proxy(
- logits_output=logits_output,
- next_pp_outputs=next_pp_outputs,
- can_run_cuda_graph=result.can_run_cuda_graph,
- )
- self.process_batch_result(mbs[next_mb_id], output_result)
- last_mbs[next_mb_id] = mbs[next_mb_id]
-
- # (not last rank)
- if not self.pp_group.is_last_rank:
- # carry the outputs to the next stage
- # send the outputs from the last round to let the next stage worker run post processing
- if pp_outputs:
- self.pp_group.send_tensor_dict(
- pp_outputs.tensors,
- all_gather_group=self.attn_tp_group,
- )
-
- # send out reqs to the next stage
- dp_offset = self.attn_dp_rank * self.attn_tp_size
- if self.attn_tp_rank == 0:
- point_to_point_pyobj(
- recv_reqs,
- self.pp_rank * self.tp_size + dp_offset,
- self.world_group.device_group,
- self.pp_rank * self.tp_size + dp_offset,
- (self.pp_rank + 1) * self.tp_size + dp_offset,
- )
-
- # send out proxy tensors to the next stage
- if self.cur_batch:
- # FIXME(lsyin): remove this assert
- assert result.pp_hidden_states_proxy_tensors.tensors is not None
- self.pp_group.send_tensor_dict(
- result.pp_hidden_states_proxy_tensors.tensors,
- all_gather_group=self.attn_tp_group,
- )
-
- pp_outputs = next_pp_outputs
-
- # When the server is idle, self-check and re-init some states
- if server_is_idle:
- # When the server is idle, do self-check and re-init some states
- self.self_check_during_idle()
+ if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
+ self._check_runtime_mem_leak()

  def recv_requests(self) -> List[Req]:
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1240,23 +1144,13 @@ class Scheduler(
1240
1144
  self.return_health_check_ct += 1
1241
1145
  continue
1242
1146
 
1243
- # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
1244
- if isinstance(recv_req, MultiTokenizerWrapper):
1245
- worker_id = recv_req.worker_id
1246
- recv_req = recv_req.obj
1247
- output = self._request_dispatcher(recv_req)
1248
- if output is not None:
1249
- output = MultiTokenizerWrapper(worker_id, output)
1250
- self.send_to_tokenizer.send_pyobj(output)
1251
- continue
1252
-
1253
1147
  output = self._request_dispatcher(recv_req)
1254
1148
  if output is not None:
1255
1149
  if isinstance(output, RpcReqOutput):
1256
1150
  if self.recv_from_rpc is not None:
1257
1151
  self.recv_from_rpc.send_pyobj(output)
1258
1152
  else:
1259
- self.send_to_tokenizer.send_pyobj(output)
1153
+ self.send_to_tokenizer.send_output(output, recv_req)
1260
1154
 
1261
1155
  def init_req_max_new_tokens(self, req):
1262
1156
  req.sampling_params.max_new_tokens = min(
@@ -1312,6 +1206,7 @@ class Scheduler(
1312
1206
  metrics_collector=(
1313
1207
  self.metrics_collector if self.enable_metrics else None
1314
1208
  ),
1209
+ http_worker_ipc=recv_req.http_worker_ipc,
1315
1210
  )
1316
1211
  req.tokenizer = self.tokenizer
1317
1212
 
@@ -1410,26 +1305,29 @@ class Scheduler(
  or req.sampling_params.ebnf is not None
  or req.sampling_params.structural_tag is not None
  ):
- assert self.grammar_backend is not None
- if req.sampling_params.json_schema is not None:
- key = ("json", req.sampling_params.json_schema)
- elif req.sampling_params.regex is not None:
- key = ("regex", req.sampling_params.regex)
- elif req.sampling_params.ebnf is not None:
- key = ("ebnf", req.sampling_params.ebnf)
- elif req.sampling_params.structural_tag:
- key = ("structural_tag", req.sampling_params.structural_tag)
-
- value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
- req.grammar = value
-
- if not cache_hit:
- req.grammar_key = key
- add_to_grammar_queue = True
+ if self.grammar_backend is None:
+ error_msg = "Grammar-based generation (json_schema, regex, ebnf, structural_tag) is not supported when the server is launched with --grammar-backend none"
+ req.set_finish_with_abort(error_msg)
  else:
- if value is INVALID_GRAMMAR_OBJ: # We hit a cached invalid grammar.
- error_msg = f"Invalid grammar request with cache hit: {key=}"
- req.set_finish_with_abort(error_msg)
+ if req.sampling_params.json_schema is not None:
+ key = ("json", req.sampling_params.json_schema)
+ elif req.sampling_params.regex is not None:
+ key = ("regex", req.sampling_params.regex)
+ elif req.sampling_params.ebnf is not None:
+ key = ("ebnf", req.sampling_params.ebnf)
+ elif req.sampling_params.structural_tag:
+ key = ("structural_tag", req.sampling_params.structural_tag)
+
+ value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+ req.grammar = value
+
+ if not cache_hit:
+ req.grammar_key = key
+ add_to_grammar_queue = True
+ else:
+ if value is INVALID_GRAMMAR_OBJ: # We hit a cached invalid grammar.
+ error_msg = f"Invalid grammar request with cache hit: {key=}"
+ req.set_finish_with_abort(error_msg)
 
  if add_to_grammar_queue:
  self.grammar_queue.append(req)
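
The hunk above turns the hard `assert` into a clean abort when the server runs with `--grammar-backend none`, and keeps the cached-or-future lookup for compiled grammars. A minimal sketch of that lookup pattern, with hypothetical names (`GrammarCache`, `compile_fn`), not sglang's backend:

```python
from concurrent.futures import Future, ThreadPoolExecutor


class GrammarCache:
    """First request for a key gets a Future compiled in the background;
    later requests for the same key hit the cache entry."""

    def __init__(self, compile_fn, workers: int = 2):
        self._cache = {}
        self._compile_fn = compile_fn
        self._pool = ThreadPoolExecutor(max_workers=workers)

    def get_cached_or_future_value(self, key):
        if key in self._cache:
            return self._cache[key], True          # cache hit
        future: Future = self._pool.submit(self._compile_fn, key)
        self._cache[key] = future                  # later lookups reuse the future
        return future, False                       # caller queues the request until ready
```

A scheduler using this pattern would park requests whose value is still a Future (the grammar queue above) and only admit them once compilation finishes or is marked invalid.
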
@@ -1456,8 +1354,18 @@ class Scheduler(
  last_hash = req.last_host_node.get_last_hash_value()
  matched_len = len(req.prefix_indices) + req.host_hit_length
  new_input_tokens = req.fill_ids[matched_len:]
+
+ prefix_keys = (
+ req.last_node.get_prefix_hash_values(req.last_node.parent)
+ if self.tree_cache.hicache_storage_pass_prefix_keys
+ else None
+ )
  self.tree_cache.prefetch_from_storage(
- req.rid, req.last_host_node, new_input_tokens, last_hash
+ req.rid,
+ req.last_host_node,
+ new_input_tokens,
+ last_hash,
+ prefix_keys,
  )
 
  def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
@@ -1489,7 +1397,11 @@ class Scheduler(
  req.priority = sys.maxsize
  else:
  req.priority = -sys.maxsize - 1
- elif not self.enable_priority_scheduling and req.priority is not None:
+ elif (
+ not self.enable_priority_scheduling
+ and req.priority is not None
+ and self.abort_on_priority_when_disabled
+ ):
  abort_req = AbortReq(
  finished_reason={
  "type": "abort",
@@ -1498,7 +1410,7 @@ class Scheduler(
  },
  rid=req.rid,
  )
- self.send_to_tokenizer.send_pyobj(abort_req)
+ self.send_to_tokenizer.send_output(abort_req, req)
 
  def _abort_on_queued_limit(self, recv_req: Req) -> bool:
  """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
@@ -1530,7 +1442,7 @@ class Scheduler(
  req_to_abort = candidate_req
  message = "The request is aborted by a higher priority request."
 
- self.send_to_tokenizer.send_pyobj(
+ self.send_to_tokenizer.send_output(
  AbortReq(
  finished_reason={
  "type": "abort",
@@ -1538,7 +1450,8 @@ class Scheduler(
  "message": message,
  },
  rid=req_to_abort.rid,
- )
+ ),
+ req_to_abort,
  )
  return req_to_abort.rid == recv_req.rid
 
@@ -1553,6 +1466,7 @@ class Scheduler(
  recv_req.sampling_params,
  token_type_ids=recv_req.token_type_ids,
  priority=recv_req.priority,
+ http_worker_ipc=recv_req.http_worker_ipc,
  )
  req.tokenizer = self.tokenizer
 
@@ -1602,109 +1516,6 @@ class Scheduler(
  for tokenized_req in recv_req:
  self.handle_embedding_request(tokenized_req)
 
- def self_check_during_idle(self):
- self.check_memory()
- self.check_tree_cache()
- self.new_token_ratio = self.init_new_token_ratio
- self.maybe_sleep_on_idle()
-
- def check_memory(self):
- if self.is_hybrid:
- (
- full_num_used,
- swa_num_used,
- _,
- _,
- full_available_size,
- full_evictable_size,
- swa_available_size,
- swa_evictable_size,
- ) = self._get_swa_token_info()
- memory_leak = full_num_used != 0 or swa_num_used != 0
- token_msg = (
- f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
- f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
- )
- else:
- _, _, available_size, evictable_size = self._get_token_info()
- protected_size = self.tree_cache.protected_size()
- memory_leak = (available_size + evictable_size) != (
- # self.max_total_num_tokens
- # if not self.enable_hierarchical_cache
- # else self.max_total_num_tokens - protected_size
- self.max_total_num_tokens
- - protected_size
- )
- token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
-
- if memory_leak:
- msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
- raise ValueError(msg)
-
- if self.disaggregation_mode == DisaggregationMode.DECODE:
- req_total_size = (
- self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
- )
- else:
- req_total_size = self.req_to_token_pool.size
-
- if len(self.req_to_token_pool.free_slots) != req_total_size:
- msg = (
- "req_to_token_pool memory leak detected!"
- f"available_size={len(self.req_to_token_pool.free_slots)}, "
- f"total_size={self.req_to_token_pool.size}\n"
- )
- raise ValueError(msg)
-
- if (
- self.enable_metrics
- and self.current_scheduler_metrics_enabled()
- and time.perf_counter() > self.metrics_collector.last_log_time + 30
- ):
- # During idle time, also collect metrics every 30 seconds.
- if self.is_hybrid:
- (
- full_num_used,
- swa_num_used,
- full_token_usage,
- swa_token_usage,
- _,
- _,
- _,
- _,
- ) = self._get_swa_token_info()
- num_used = max(full_num_used, swa_num_used)
- token_usage = max(full_token_usage, swa_token_usage)
- else:
- num_used, token_usage, _, _ = self._get_token_info()
- num_running_reqs = len(self.running_batch.reqs)
- self.stats.num_running_reqs = num_running_reqs
- self.stats.num_used_tokens = num_used
- self.stats.token_usage = round(token_usage, 2)
- self.stats.gen_throughput = 0
- self.stats.num_queue_reqs = len(self.waiting_queue)
- self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
- self.stats.num_prefill_prealloc_queue_reqs = len(
- self.disagg_prefill_bootstrap_queue.queue
- )
- self.stats.num_prefill_inflight_queue_reqs = len(
- self.disagg_prefill_inflight_queue
- )
- if self.disaggregation_mode == DisaggregationMode.DECODE:
- self.stats.num_decode_prealloc_queue_reqs = len(
- self.disagg_decode_prealloc_queue.queue
- )
- self.stats.num_decode_transfer_queue_reqs = len(
- self.disagg_decode_transfer_queue.queue
- )
- self.metrics_collector.log_stats(self.stats)
- self._publish_kv_events()
-
- def check_tree_cache(self):
- if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
- self.tree_cache.sanity_check()
-
  def _get_token_info(self):
  available_size = self.token_to_kv_pool_allocator.available_size()
  evictable_size = self.tree_cache.evictable_size()
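
The removed `check_memory` above enforced, at idle time, that every KV-cache token is either free, evictable from the radix cache, or protected by an in-flight request. Restated as a standalone sketch of that invariant (plain integers, not the scheduler's objects):

```python
def assert_no_kv_leak(
    max_total_num_tokens: int,
    available_size: int,
    evictable_size: int,
    protected_size: int,
) -> None:
    # Invariant checked when the scheduler is idle: free + evictable tokens
    # must account for everything that is not protected by a running request.
    if available_size + evictable_size != max_total_num_tokens - protected_size:
        raise ValueError(
            "token_to_kv_pool_allocator memory leak detected! "
            f"{max_total_num_tokens=}, {available_size=}, "
            f"{evictable_size=}, {protected_size=}"
        )
```
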
@@ -1712,6 +1523,35 @@ class Scheduler(
  token_usage = num_used / self.max_total_num_tokens
  return num_used, token_usage, available_size, evictable_size
 
+ def _get_mamba_token_info(self):
+ is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
+ full_available_size = self.token_to_kv_pool_allocator.available_size()
+ full_evictable_size = (
+ self.tree_cache.full_evictable_size() if is_radix_tree else 0
+ )
+ mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
+ mamba_evictable_size = (
+ self.tree_cache.mamba_evictable_size() if is_radix_tree else 0
+ )
+ full_num_used = self.token_to_kv_pool_allocator.size - (
+ full_available_size + full_evictable_size
+ )
+ mamba_num_used = self.req_to_token_pool.mamba_pool.size - (
+ mamba_available_size + mamba_evictable_size
+ )
+ full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
+ mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
+ return (
+ full_num_used,
+ mamba_num_used,
+ full_token_usage,
+ mamba_usage,
+ full_available_size,
+ full_evictable_size,
+ mamba_available_size,
+ mamba_evictable_size,
+ )
+
  def _get_swa_token_info(self):
  full_available_size = self.token_to_kv_pool_allocator.full_available_size()
  full_evictable_size = self.tree_cache.full_evictable_size()
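
The new `_get_mamba_token_info` applies the same accounting used for the full KV pool: a slot is "used" unless it is free or evictable, and usage is the used fraction of the pool. A generic restatement with plain numbers (not sglang's pool objects):

```python
def pool_usage(pool_size: int, available: int, evictable: int) -> tuple[int, float]:
    """Return (num_used, usage_fraction) for a fixed-size slot pool."""
    num_used = pool_size - (available + evictable)
    return num_used, num_used / pool_size

# e.g. pool_usage(1000, 700, 100) -> (200, 0.2)
```
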
@@ -1745,7 +1585,7 @@ class Scheduler(
  chunked_req_to_exclude.add(self.chunked_req)
  self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
  # chunked request keeps its rid but will get a new req_pool_idx
- if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+ if self.tp_worker.model_runner.mambaish_config is not None:
  self.req_to_token_pool.free(
  self.chunked_req.req_pool_idx, free_mamba_cache=False
  )
@@ -1802,7 +1642,7 @@ class Scheduler(
  return ret
 
  def get_num_allocatable_reqs(self, running_bs):
- res = global_server_args_dict["max_micro_batch_size"] - running_bs
+ res = get_global_server_args().pp_max_micro_batch_size - running_bs
  if self.pp_size > 1:
  res = min(res, self.req_to_token_pool.available_size())
  return res
@@ -1999,7 +1839,7 @@ class Scheduler(
 
  # Check if decode out of memory
  if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
- TEST_RETRACT and batch.batch_size() > 10
+ TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
  ):
  old_ratio = self.new_token_ratio
  retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
@@ -2008,8 +1848,8 @@ class Scheduler(
  self.num_retracted_reqs = len(retracted_reqs)
  self.new_token_ratio = new_token_ratio
  for req in reqs_to_abort:
- self.send_to_tokenizer.send_pyobj(
- AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
+ self.send_to_tokenizer.send_output(
+ AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
  )
 
  logger.info(
@@ -2034,6 +1874,12 @@ class Scheduler(
  batch.prepare_for_decode()
  return batch
 
+ # placeholder for override
+ def update_cache_from_scheduler(
+ self, schedule_batch: ScheduleBatch, batch_result: GenerationBatchResult
+ ):
+ pass
+
  def run_batch(
  self, batch: ScheduleBatch
  ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
@@ -2051,22 +1897,72 @@ class Scheduler(
 
  batch_or_worker_batch = batch
 
- if self.spec_algorithm.is_none():
+ if self.enable_overlap or self.spec_algorithm.is_none():
  # FIXME(lsyin): remove this if and finally unify the abstraction
  batch_or_worker_batch = batch.get_model_worker_batch()
 
- forward_batch_output = self.model_worker.forward_batch_generation(
- batch_or_worker_batch
- )
+ if self.enable_overlap:
+ # FIXME: remove this assert
+ assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
+ model_worker_batch = batch_or_worker_batch
+ self.record_batch_in_overlap(model_worker_batch)
+
+ # Sampling info will be modified during forward
+ model_worker_batch.sampling_info = (
+ model_worker_batch.sampling_info.copy_for_forward()
+ )
+
+ bs = len(model_worker_batch.seq_lens)
+ future_indices = self.future_map.alloc_future_indices(bs)
+
+ with self.forward_stream_ctx:
+ self.forward_stream.wait_stream(self.default_stream)
+ self.future_map.resolve_future(model_worker_batch)
+ batch_result = self.model_worker.forward_batch_generation(
+ model_worker_batch
+ )
+ # FIXME(lsyin): maybe move this to forward_batch_generation
+ batch_result.copy_done = torch.get_device_module(
+ self.device
+ ).Event()
+ if batch_result.delay_sample_func is None:
+ self.future_map.store_to_map(future_indices, batch_result)
+ batch_result.copy_to_cpu()
+ else:
+ batch_result.future_indices = future_indices
+
+ # FIXME(lsyin): move this assignment elsewhere
+ future_indices_or_next_token_ids = -future_indices.indices
+
+ if batch.is_v2_eagle:
+ # FIXME(lsyin): tmp code for eagle v2
+ # We only keep future indices for next draft input
 
- if not self.spec_algorithm.is_none():
- # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
- self.udpate_spec_metrics(
- batch.batch_size(), forward_batch_output.num_accepted_tokens
+ batch.spec_info = batch_result.next_draft_input
+ batch.spec_info.future_indices = future_indices
+
+ # batch.spec_info = EagleDraftInput(
+ # future_indices=future_indices,
+ # verify_done=batch_result.next_draft_input.verify_done,
+ # # FIXME(lsyin): remove the allocate_lens in EagleDraftInput
+ # allocate_lens=batch_result.next_draft_input.allocate_lens,
+ # )
+
+ # The future value, usually for next batch preparation
+ # Current implementation strictly synchronizes the seq_lens
+ batch.seq_lens = batch_result.next_draft_input.new_seq_lens
+ else:
+ batch_result = self.model_worker.forward_batch_generation(
+ batch_or_worker_batch
  )
+ future_indices_or_next_token_ids = batch_result.next_token_ids
+ self.update_cache_from_scheduler(batch, batch_result)
 
- # update batch's output ids
- batch.output_ids = forward_batch_output.next_token_ids
+ # NOTE: future_indices_or_next_token_ids is used in ScheduleBatch,
+ # which can probably be replaced by future_indices later [TODO(lsyin)].
+ # we shall still keep the original outputs, e.g. next_token_ids
+ # in the GenerationBatchOutput for processing after copy_done.
+ batch.output_ids = future_indices_or_next_token_ids
 
  # These 2 values are needed for processing the output, but the values can be
  # modified by overlap schedule. So we have to copy them here so that
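
In the overlap path above, `batch.output_ids` is filled with negated future indices, i.e. placeholders for token ids that the still-running GPU batch will produce; they are resolved into real ids when the next batch is prepared. A toy sketch of that placeholder idea (`ToyFutureMap` is illustrative only, not sglang's `FutureMap`):

```python
import torch


class ToyFutureMap:
    """Negative indices act as IOUs for token ids that a later step fills in."""

    def __init__(self, capacity: int):
        self.buf = torch.zeros(capacity + 1, dtype=torch.long)  # slot 0 unused
        self.next = 1

    def alloc_future_indices(self, bs: int) -> torch.Tensor:
        idx = torch.arange(self.next, self.next + bs)
        self.next += bs
        return idx

    def store(self, indices: torch.Tensor, next_token_ids: torch.Tensor) -> None:
        self.buf[indices] = next_token_ids

    def resolve(self, maybe_future: torch.Tensor) -> torch.Tensor:
        # Negative entries are placeholders; positive entries are real token ids.
        return torch.where(maybe_future < 0, self.buf[-maybe_future], maybe_future)
```

Usage mirrors the hunk: allocate indices, write `-indices` into `output_ids`, have the worker `store()` the sampled ids once the forward finishes, and `resolve()` them when building the next batch's input.
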
@@ -2083,39 +1979,51 @@ class Scheduler(
  else:
  extend_logprob_start_len_per_req = None
 
- return GenerationBatchResult.from_forward_batch_output(
- forward_batch_output=forward_batch_output,
- extend_input_len_per_req=extend_input_len_per_req,
- extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+ batch_result.extend_input_len_per_req = extend_input_len_per_req
+ batch_result.extend_logprob_start_len_per_req = (
+ extend_logprob_start_len_per_req
  )
+ return batch_result
  else: # embedding or reward model
  model_worker_batch = batch.get_model_worker_batch()
  embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
  ret = EmbeddingBatchResult(embeddings=embeddings)
  return ret
 
+ def launch_batch_sample_if_needed(
+ self, batch_result: GenerationBatchResult
+ ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
+ # TODO(lsyin): make the delayed sample a default behavior after
+ # unifying the forward_batch_generation interface (related to spec V2).
+ if batch_result is None or batch_result.delay_sample_func is None:
+ return
+
+ with self.forward_stream_ctx:
+ self.forward_stream.wait_stream(self.default_stream)
+ _batch_result = batch_result.delay_sample_func()
+ assert _batch_result is batch_result
+ self.future_map.store_to_map(batch_result.future_indices, batch_result)
+ batch_result.copy_to_cpu()
+
  def process_batch_result(
  self,
  batch: ScheduleBatch,
  result: Union[GenerationBatchResult, EmbeddingBatchResult],
- launch_done: Optional[threading.Event] = None,
  ):
  if batch.forward_mode.is_decode():
- self.process_batch_result_decode(batch, result, launch_done)
+ self.process_batch_result_decode(batch, result)
  if self.enable_trace:
  trace_slice_batch("decode loop", batch.reqs)
 
  elif batch.forward_mode.is_extend():
- self.process_batch_result_prefill(batch, result, launch_done)
+ self.process_batch_result_prefill(batch, result)
  if self.enable_trace:
  trace_slice_batch("prefill", batch.reqs)
 
  elif batch.forward_mode.is_idle():
  if self.enable_overlap:
- self.tp_worker.resolve_last_batch_result(launch_done)
- self.set_next_batch_sampling_info_done(batch)
- elif batch.forward_mode.is_dummy_first():
- self.set_next_batch_sampling_info_done(batch)
+ if result.copy_done is not None:
+ result.copy_done.synchronize()
 
  self.maybe_send_health_check_signal()
 
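
The `launch_done` threading.Event in the old signature is replaced above by a device event (`copy_done`) recorded after the device-to-host copy, and result processing simply synchronizes on it. A generic sketch of that pattern (`model_step` and the tensor shapes are hypothetical, not the scheduler's code):

```python
import torch


def run_and_copy(model_step, batch, device: str = "cuda"):
    # Launch the forward step and an async device-to-host copy, then hand back
    # an event that marks when the host-side tensors are safe to read.
    # (For a truly asynchronous copy the destination should be pinned memory.)
    out = model_step(batch)                       # GPU tensor
    host_out = out.to("cpu", non_blocking=True)   # async D2H copy
    copy_done = torch.get_device_module(device).Event()
    copy_done.record()
    return host_out, copy_done

# Later, right before post-processing:
#   copy_done.synchronize()  # host_out is now valid
```
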
@@ -2125,7 +2033,7 @@ class Scheduler(
  # This is used to prevent the health check signal being blocked by long context prefill.
  # However, one minor issue is that this code path does not check the status of detokenizer manager.
  self.return_health_check_ct -= 1
- self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
+ self.send_to_tokenizer.send_output(HealthCheckOutput())
 
  def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
  return self.prepare_mlp_sync_batch_raw(
@@ -2139,6 +2047,7 @@ class Scheduler(
  speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
  require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
  disable_overlap_schedule=self.server_args.disable_overlap_schedule,
+ offload_tags=self.offload_tags,
  )
 
  @staticmethod
@@ -2153,6 +2062,7 @@ class Scheduler(
  speculative_num_draft_tokens,
  require_mlp_tp_gather: bool,
  disable_overlap_schedule: bool,
+ offload_tags: set[str],
  ):
  # Check if other DP workers have running batches
  if local_batch is None:
@@ -2163,15 +2073,18 @@ class Scheduler(
  num_tokens_for_logprob = num_tokens
  else:
  num_tokens = local_batch.extend_num_tokens
- num_tokens_for_logprob = sum(
- [
+ if local_batch.return_logprob:
+ num_tokens_for_logprob = sum(
  # We should have at least 1 token for sample in every case.
  max(extend_len - logprob_start_len, 1)
  for logprob_start_len, extend_len in zip(
- local_batch.extend_logprob_start_lens, local_batch.extend_lens
+ local_batch.extend_logprob_start_lens,
+ local_batch.extend_lens,
  )
- ]
- )
+ )
+ else:
+ # When return_logprob = False, only need last token per request
+ num_tokens_for_logprob = local_batch.batch_size()
 
  if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
  can_cuda_graph = 1
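
The branch above stops counting every extend token for the logprob path when no logprobs were requested. A standalone restatement with plain lists (not the scheduler's data structures):

```python
def num_tokens_for_logprob(extend_lens, logprob_start_lens, return_logprob: bool) -> int:
    if return_logprob:
        # At least one token per request must still go through sampling.
        return sum(max(e - s, 1) for s, e in zip(logprob_start_lens, extend_lens))
    # Otherwise only the last token of each request is needed.
    return len(extend_lens)

# num_tokens_for_logprob([100, 8], [0, 8], True)  -> 101
# num_tokens_for_logprob([100, 8], [0, 8], False) -> 2
```
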
@@ -2183,7 +2096,7 @@ class Scheduler(
  )
 
  tbo_preparer = TboDPAttentionPreparer()
- if disable_overlap_schedule:
+ if len(offload_tags) == 0 and disable_overlap_schedule:
  group = tp_group.device_group
  device = tp_group.device
  else:
@@ -2325,13 +2238,6 @@ class Scheduler(
  self._add_request_to_queue(req)
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]
 
- def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
- if batch.next_batch_sampling_info:
- if batch.next_batch_sampling_info.grammars is not None:
- batch.next_batch_sampling_info.update_regex_vocab_mask()
- self.current_stream.synchronize()
- batch.next_batch_sampling_info.sampling_info_done.set()
-
  def watchdog_thread(self):
  """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
  self.watchdog_last_forward_ct = 0
@@ -2419,10 +2325,10 @@ class Scheduler(
 
  self.num_generated_tokens = 0
  self.forward_ct_decode = 0
- self.spec_num_total_accepted_tokens = 0
- self.spec_num_total_forward_ct = 0
- self.cum_spec_accept_length = 0
- self.cum_spec_accept_count = 0
+ self.spec_num_accepted_tokens = 0
+ self.spec_num_forward_ct = 0
+ self.spec_total_num_accepted_tokens = 0
+ self.spec_total_num_forward_ct = 0
  torch.cuda.empty_cache()
  logger.info("Cache flushed successfully!")
  if_success = True
@@ -2481,12 +2387,10 @@ class Scheduler(
  )
 
  def get_internal_state(self, recv_req: GetInternalStateReq):
- ret = dict(global_server_args_dict)
+ ret = vars(get_global_server_args())
  ret["last_gen_throughput"] = self.last_gen_throughput
  ret["memory_usage"] = {
- "weight": round(
- self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
- ),
+ "weight": round(self.tp_worker.model_runner.weight_load_mem_usage, 2),
  "kvcache": round(
  self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
  ),
@@ -2494,23 +2398,26 @@ class Scheduler(
  }
 
  ret["memory_usage"]["graph"] = round(
- self.tp_worker.worker.model_runner.graph_mem_usage, 2
+ self.tp_worker.model_runner.graph_mem_usage, 2
  )
 
- if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+ if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
  ret["avg_spec_accept_length"] = (
- self.cum_spec_accept_length / self.cum_spec_accept_count
+ self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
  )
  if RECORD_STEP_TIME:
  ret["step_time_dict"] = self.step_time_dict
 
+ # This field is not serializable.
+ ret.pop("model_config", None)
+
  return GetInternalStateReqOutput(internal_state=ret)
 
  def set_internal_state(self, recv_req: SetInternalStateReq):
  server_args_dict = recv_req.server_args
  args_allow_update = set(
  [
- "max_micro_batch_size",
+ "pp_max_micro_batch_size",
  "speculative_accept_threshold_single",
  "speculative_accept_threshold_acc",
  ]
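
Both internal-state paths compute the same ratio from the renamed speculative-decoding counters: tokens accepted across all forward passes divided by the number of forward passes. Restated as a small helper local to this sketch:

```python
from typing import Optional


def avg_spec_accept_length(total_accepted_tokens: int, total_forward_ct: int) -> Optional[float]:
    # Average number of tokens accepted per speculative forward pass;
    # undefined until at least one forward pass has been counted.
    if total_forward_ct <= 0:
        return None
    return total_accepted_tokens / total_forward_ct

# avg_spec_accept_length(384, 128) -> 3.0
```
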
@@ -2521,7 +2428,7 @@ class Scheduler(
  logging.warning(f"Updating {k} is not supported.")
  if_success = False
  break
- elif k == "max_micro_batch_size" and (
+ elif k == "pp_max_micro_batch_size" and (
  v > self.max_running_requests // self.pp_size or v < 1
  ):
  logging.warning(
@@ -2530,18 +2437,18 @@ class Scheduler(
  if_success = False
  break
  if if_success:
- if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+ if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
  avg_spec_accept_length = (
- self.cum_spec_accept_length / self.cum_spec_accept_count
+ self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
  )
  logger.info(f"{avg_spec_accept_length=}")
- self.cum_spec_accept_length = self.cum_spec_accept_count = 0
+ self.spec_total_num_accepted_tokens = self.spec_total_num_forward_ct = 0
  for k, v in server_args_dict.items():
- global_server_args_dict[k] = v
- logger.info(f"Global server args updated! {global_server_args_dict=}")
+ setattr(get_global_server_args(), k, v)
+ logger.info(f"Global server args updated! {get_global_server_args()=}")
  return SetInternalStateReqOutput(
  updated=True,
- server_args=global_server_args_dict,
+ server_args=vars(get_global_server_args()),
  )
 
  def handle_rpc_request(self, recv_req: RpcReqInput):
@@ -2579,7 +2486,7 @@ class Scheduler(
  if self.enable_hicache_storage:
  # to release prefetch events associated with the request
  self.tree_cache.release_aborted_request(req.rid)
- self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
+ self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req)
  # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
  if self.disaggregation_mode == DisaggregationMode.DECODE:
  self.tree_cache.cache_finished_req(req)
@@ -2663,10 +2570,6 @@ class Scheduler(
  result = self.tp_worker.unload_lora_adapter(recv_req)
  return result
 
- def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
- self.send_to_detokenizer.send_pyobj(recv_req)
- return recv_req
-
  def init_weights_send_group_for_remote_instance(
  self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
  ):
@@ -2745,7 +2648,7 @@ class Scheduler(
  def handle_freeze_gc(self, recv_req: FreezeGCReq):
  """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
  freeze_gc("Scheduler")
- self.send_to_detokenizer.send_pyobj(recv_req)
+ self.send_to_detokenizer.send_output(recv_req, recv_req)
  return None
 
 
@@ -2767,12 +2670,13 @@ class IdleSleeper:
  for s in sockets:
  self.poller.register(s, zmq.POLLIN)
 
+ self.empty_cache_interval = envs.SGLANG_EMPTY_CACHE_INTERVAL.get()
+
  def maybe_sleep(self):
  self.poller.poll(1000)
  if (
- global_config.torch_empty_cache_interval > 0
- and time.time() - self.last_empty_time
- > global_config.torch_empty_cache_interval
+ self.empty_cache_interval > 0
+ and time.time() - self.last_empty_time > self.empty_cache_interval
  ):
  self.last_empty_time = time.time()
  torch.cuda.empty_cache()
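
The IdleSleeper now reads its interval from the `SGLANG_EMPTY_CACHE_INTERVAL` environment-variable wrapper instead of `global_config`; the gating itself is a simple rate limit. A standalone restatement (the env var name comes from the hunk above, the rest, including the default, is illustrative):

```python
import os
import time


class IdleCacheFlusher:
    """Illustrative rate-limited flusher: empty the allocator cache at most
    once per interval while idle; a non-positive interval disables it."""

    def __init__(self):
        self.interval = float(os.environ.get("SGLANG_EMPTY_CACHE_INTERVAL", "-1"))
        self.last_empty_time = time.time()

    def maybe_flush(self, empty_cache_fn) -> None:
        if self.interval > 0 and time.time() - self.last_empty_time > self.interval:
            self.last_empty_time = time.time()
            empty_cache_fn()  # e.g. torch.cuda.empty_cache()
```
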
@@ -2831,7 +2735,9 @@ def run_scheduler_process(
 
  # Set cpu affinity to this gpu process
  if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
- set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+ set_gpu_proc_affinity(
+ server_args.pp_size, server_args.tp_size, server_args.nnodes, gpu_id
+ )
  if (numa_node := server_args.numa_node) is not None:
  numa_bind_to_node(numa_node[gpu_id])