sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,12 @@
1
- import hashlib
2
1
  import logging
3
2
  import os
4
3
  import time
5
4
  import uuid
6
- from typing import Any, Dict, List, Optional, Tuple, Union
5
+ from typing import Any, List, Optional, Union
7
6
 
8
7
  import torch
9
8
 
10
- from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
9
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
11
10
 
12
11
  from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
13
12
 
@@ -26,7 +25,12 @@ logger = logging.getLogger(__name__)
26
25
  class HiCacheNixl(HiCacheStorage):
27
26
  """HiCacheNixl provides high-performance storage using NIXL plugins."""
28
27
 
29
- def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
28
+ def __init__(
29
+ self,
30
+ storage_config: HiCacheStorageConfig,
31
+ file_path: str = "/tmp/hicache_storage",
32
+ plugin: str = "auto",
33
+ ):
30
34
  """Initialize NIXL storage connector."""
31
35
  # Might be better to be unified across HiCache backends and moved to HiCacheController
32
36
  file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
@@ -36,6 +40,19 @@ class HiCacheNixl(HiCacheStorage):
36
40
  else None
37
41
  )
38
42
 
43
+ # Initialize suffix based on storage config
44
+ tp_rank, tp_size, model_name, is_mla_model = (
45
+ storage_config.tp_rank,
46
+ storage_config.tp_size,
47
+ storage_config.model_name,
48
+ storage_config.is_mla_model,
49
+ )
50
+ model_name = "-".join(model_name.split("/")) if model_name else ""
51
+ if is_mla_model:
52
+ self.config_suffix = f"_{model_name}"
53
+ else:
54
+ self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
55
+
39
56
  agent_config = nixl_agent_config(backends=[])
40
57
  self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
41
58
  self.agent = nixl_agent(self.agent_name, agent_config)
@@ -46,6 +63,9 @@ class HiCacheNixl(HiCacheStorage):
46
63
 
47
64
  self.registration = NixlRegistration(self.agent)
48
65
 
66
+ def _get_suffixed_key(self, key: str) -> str:
67
+ return key + self.config_suffix
68
+
49
69
  def register_buffers(
50
70
  self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
51
71
  ) -> Optional[Any]:
@@ -194,11 +214,14 @@ class HiCacheNixl(HiCacheStorage):
194
214
  else:
195
215
  dest = target_locations
196
216
 
217
+ # Add suffix to keys
218
+ suffixed_keys = [self._get_suffixed_key(key) for key in keys]
219
+
197
220
  if self.backend_selector.mem_type == "FILE":
198
- file_paths = [self.file_manager.get_file_path(key) for key in keys]
221
+ file_paths = [self.file_manager.get_file_path(key) for key in suffixed_keys]
199
222
  success = self._execute_transfer(dest, file_paths, "READ")
200
223
  else:
201
- success = self._execute_transfer(dest, keys, "READ")
224
+ success = self._execute_transfer(dest, suffixed_keys, "READ")
202
225
  return target_locations if success and not target_sizes else [None] * len(keys)
203
226
 
204
227
  def set(
@@ -227,9 +250,12 @@ class HiCacheNixl(HiCacheStorage):
227
250
  if not values:
228
251
  values = list(zip(target_locations, target_sizes))
229
252
 
253
+ # Add suffix to keys
254
+ suffixed_keys = [self._get_suffixed_key(key) for key in keys]
255
+
230
256
  if self.backend_selector.mem_type == "FILE":
231
257
  file_paths = []
232
- for key in keys:
258
+ for key in suffixed_keys:
233
259
  file_path = self.file_manager.get_file_path(key)
234
260
  # New file per set, to be updated when partial writes is added to HiCache
235
261
  if not self.file_manager.create_file(file_path):
@@ -238,11 +264,14 @@ class HiCacheNixl(HiCacheStorage):
238
264
  file_paths.append(file_path)
239
265
  return self._execute_transfer(values, file_paths, "WRITE")
240
266
  else: # mem_type == "OBJ"
241
- return self._execute_transfer(values, keys, "WRITE")
267
+ return self._execute_transfer(values, suffixed_keys, "WRITE")
242
268
 
243
269
  def exists(self, key: str) -> bool:
270
+ # Add suffix to key
271
+ suffixed_key = self._get_suffixed_key(key)
272
+
244
273
  tuples = self.registration.create_query_tuples(
245
- key,
274
+ suffixed_key,
246
275
  self.backend_selector.mem_type,
247
276
  self.file_manager if self.backend_selector.mem_type == "FILE" else None,
248
277
  )
@@ -1,6 +1,6 @@
1
1
  import logging
2
2
  import os
3
- from typing import Any, Dict, List, Optional, Tuple, Union
3
+ from typing import Any, List, Optional, Tuple, Union
4
4
 
5
5
  import torch
6
6
 
@@ -2,11 +2,12 @@
2
2
 
3
3
  import os
4
4
  import unittest
5
- from typing import List, Optional
5
+ from typing import List
6
6
  from unittest.mock import MagicMock
7
7
 
8
8
  import torch
9
9
 
10
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
10
11
  from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
11
12
  from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
12
13
  NixlFileManager,
@@ -31,8 +32,22 @@ class TestNixlUnified(unittest.TestCase):
31
32
  # Create instances
32
33
  self.file_manager = NixlFileManager(self.test_dir)
33
34
  self.registration = NixlRegistration(self.mock_agent)
35
+
36
+ # Create storage config for testing
37
+ self.storage_config = HiCacheStorageConfig(
38
+ tp_rank=0,
39
+ tp_size=2,
40
+ is_mla_model=False,
41
+ is_page_first_layout=False,
42
+ model_name="test_model",
43
+ )
44
+
34
45
  try:
35
- self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX")
46
+ self.hicache = HiCacheNixl(
47
+ storage_config=self.storage_config,
48
+ file_path=self.test_dir,
49
+ plugin="POSIX",
50
+ )
36
51
  except ImportError:
37
52
  self.skipTest("NIXL not available, skipping NIXL storage tests")
38
53
 
@@ -32,6 +32,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
32
32
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
33
33
  from sglang.srt.mem_cache.radix_cache import (
34
34
  RadixKey,
35
+ _convert_to_bigram_key,
35
36
  _key_match_page_size1,
36
37
  _key_match_paged,
37
38
  get_child_key,
@@ -327,12 +328,14 @@ class SWARadixCache(BasePrefixCache):
327
328
  sliding_window_size: int,
328
329
  page_size: int,
329
330
  disable: bool = False,
331
+ is_eagle: bool = False,
330
332
  ):
331
333
  assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
332
334
  self.req_to_token_pool = req_to_token_pool
333
335
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
334
336
  self.page_size = page_size
335
337
  self.disable = disable
338
+ self.is_eagle = is_eagle
336
339
 
337
340
  if self.token_to_kv_pool_allocator:
338
341
  self.device = self.token_to_kv_pool_allocator.device
@@ -346,6 +349,11 @@ class SWARadixCache(BasePrefixCache):
346
349
  self.key_match_fn = partial(_key_match_paged, page_size=page_size)
347
350
  self.get_child_key_fn = partial(get_child_key, page_size=page_size)
348
351
 
352
+ if is_eagle:
353
+ self.key_convert_fn = _convert_to_bigram_key
354
+ else:
355
+ self.key_convert_fn = lambda key: key
356
+
349
357
  self.sliding_window_size = sliding_window_size
350
358
  self.reset()
351
359
 
@@ -376,6 +384,8 @@ class SWARadixCache(BasePrefixCache):
376
384
  The last node create a new child if the prefix is shorter
377
385
  than the last node's value.
378
386
  """
387
+ key.token_ids = self.key_convert_fn(key.token_ids)
388
+
379
389
  if self.disable or len(key) == 0:
380
390
  return MatchResult(
381
391
  device_indices=torch.empty(
@@ -406,42 +416,73 @@ class SWARadixCache(BasePrefixCache):
406
416
  if self.disable:
407
417
  return 0
408
418
 
419
+ key.token_ids = self.key_convert_fn(key.token_ids)
420
+
409
421
  if value is None:
410
422
  value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
423
+
424
+ if self.is_eagle:
425
+ # Make sure the value len equal to the EAGLE bigram key len
426
+ value = value[: len(key)]
427
+
411
428
  return self._insert_helper(self.root_node, key, value, prev_prefix_len)
412
429
 
413
- def cache_finished_req(self, req: Req) -> None:
430
+ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
414
431
  """Cache request when it finishes."""
432
+ all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
415
433
  if self.disable:
416
434
  kv_indices = self.req_to_token_pool.req_to_token[
417
- req.req_pool_idx,
418
- : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
435
+ req.req_pool_idx, :all_token_len
419
436
  ]
420
437
  self.token_to_kv_pool_allocator.free(kv_indices)
421
438
  self.req_to_token_pool.free(req.req_pool_idx)
422
439
  return
423
440
 
424
- token_ids = (req.origin_input_ids + req.output_ids)[:-1]
441
+ token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
442
+ # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
443
+ # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
444
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
425
445
  kv_indices = self.req_to_token_pool.req_to_token[
426
- req.req_pool_idx, : len(token_ids)
446
+ req.req_pool_idx, :all_token_len
427
447
  ]
428
448
 
429
449
  if self.page_size != 1:
430
- page_aligned_len = len(kv_indices) // self.page_size * self.page_size
431
- page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
432
- self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
450
+ page_aligned_len = actual_kv_len // self.page_size * self.page_size
451
+ page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
452
+ dtype=torch.int64, copy=True
453
+ )
433
454
  else:
434
- page_aligned_len = len(kv_indices)
435
- page_aligned_kv_indices = kv_indices.clone()
455
+ page_aligned_len = actual_kv_len
456
+ page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
457
+ if self.is_eagle:
458
+ self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
459
+
460
+ page_aligned_token_len = (
461
+ page_aligned_len + 1 if self.is_eagle else page_aligned_len
462
+ )
463
+
464
+ old_prefix_len = len(req.prefix_indices)
465
+ if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
466
+ # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
467
+ # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
468
+ old_prefix_len -= 1
436
469
 
437
470
  # Radix Cache takes one ref in memory pool
438
471
  # insert the token_ids and kv_indices into the radix tree
439
472
  # Note: the insert function already frees the overlapped kv_indices
440
- new_prefix_len = self.insert(
441
- RadixKey(token_ids[:page_aligned_len], req.extra_key),
442
- page_aligned_kv_indices,
443
- len(req.prefix_indices),
444
- )
473
+ if is_insert:
474
+ new_prefix_len = self.insert(
475
+ RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
476
+ page_aligned_kv_indices,
477
+ old_prefix_len,
478
+ )
479
+ else:
480
+ self.token_to_kv_pool_allocator.free(
481
+ kv_indices[old_prefix_len:page_aligned_len]
482
+ )
483
+
484
+ # free the unaligned tail
485
+ self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
445
486
 
446
487
  # Remove req slot release the cache lock
447
488
  self.req_to_token_pool.free(req.req_pool_idx)
@@ -459,39 +500,58 @@ class SWARadixCache(BasePrefixCache):
459
500
  return
460
501
 
461
502
  token_ids = req.fill_ids
503
+ all_token_len = len(token_ids)
504
+ # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
505
+ # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
506
+ actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
462
507
  kv_indices = self.req_to_token_pool.req_to_token[
463
- req.req_pool_idx, : len(token_ids)
508
+ req.req_pool_idx, :all_token_len
464
509
  ]
465
510
 
466
511
  if self.page_size != 1:
467
- page_aligned_len = len(kv_indices) // self.page_size * self.page_size
468
- page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
512
+ page_aligned_len = actual_kv_len // self.page_size * self.page_size
513
+ page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
514
+ dtype=torch.int64, copy=True
515
+ )
469
516
  else:
470
- page_aligned_len = len(kv_indices)
471
- page_aligned_kv_indices = kv_indices.clone()
472
- page_aligned_token_ids = token_ids[:page_aligned_len]
517
+ page_aligned_len = actual_kv_len
518
+ page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
519
+
520
+ # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
521
+ page_aligned_token_len = (
522
+ page_aligned_len + 1 if self.is_eagle else page_aligned_len
523
+ )
524
+ page_aligned_token_ids = token_ids[:page_aligned_token_len]
525
+
526
+ old_prefix_len = len(req.prefix_indices)
527
+ if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
528
+ # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
529
+ # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
530
+ old_prefix_len -= 1
473
531
 
474
532
  # Radix Cache takes one ref in memory pool
475
533
  # Note: the insert function already frees the overlapped kv_indices
476
534
  new_prefix_len = self.insert(
477
535
  RadixKey(page_aligned_token_ids, req.extra_key),
478
536
  page_aligned_kv_indices,
479
- len(req.prefix_indices),
537
+ old_prefix_len,
480
538
  )
481
539
 
482
540
  # The prefix indices could be updated, reuse it
483
541
  new_indices, new_last_node, _, _ = self.match_prefix(
484
542
  RadixKey(page_aligned_token_ids, req.extra_key)
485
543
  )
486
- assert len(req.prefix_indices) <= len(
544
+ assert old_prefix_len <= len(
487
545
  new_indices
488
546
  ), f"{req.prefix_indices=}, {new_indices=}"
489
547
  assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
490
548
  self.req_to_token_pool.write(
491
- (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
492
- new_indices[len(req.prefix_indices) :],
549
+ (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
550
+ new_indices[old_prefix_len:],
493
551
  )
494
552
 
553
+ req.last_matched_prefix_len = len(new_indices)
554
+
495
555
  self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
496
556
  swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
497
557
 
@@ -501,7 +561,13 @@ class SWARadixCache(BasePrefixCache):
501
561
  [new_indices, kv_indices[len(new_indices) :]]
502
562
  )
503
563
  else:
504
- req.prefix_indices = new_indices
564
+ if self.is_eagle:
565
+ # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
566
+ req.prefix_indices = torch.cat(
567
+ [new_indices, kv_indices[actual_kv_len:]]
568
+ )
569
+ else:
570
+ req.prefix_indices = new_indices
505
571
  req.last_node = new_last_node
506
572
  req.swa_uuid_for_lock = swa_uuid_for_lock
507
573
 
@@ -118,6 +118,7 @@ class SchedulerStats:
118
118
  num_running_reqs: int = 0
119
119
  num_used_tokens: int = 0
120
120
  token_usage: float = 0.0
121
+ pending_prealloc_token_usage: float = 0.0
121
122
  swa_token_usage: float = 0.0
122
123
  gen_throughput: float = 0.0
123
124
  num_queue_reqs: int = 0
@@ -127,6 +128,7 @@ class SchedulerStats:
127
128
 
128
129
  # Speculative decoding
129
130
  spec_accept_length: float = 0.0
131
+ spec_accept_rate: float = 0.0
130
132
 
131
133
  # Retract
132
134
  num_retracted_reqs: int = 0
@@ -148,6 +150,9 @@ class SchedulerStats:
148
150
  engine_startup_time: float = 0.0
149
151
  engine_load_weights_time: float = 0.0
150
152
 
153
+ # CUDA graph
154
+ is_cuda_graph: float = 0.0
155
+
151
156
 
152
157
  class SchedulerMetricsCollector:
153
158
 
@@ -176,6 +181,12 @@ class SchedulerMetricsCollector:
176
181
  labelnames=labels.keys(),
177
182
  multiprocess_mode="mostrecent",
178
183
  )
184
+ self.pending_prealloc_token_usage = Gauge(
185
+ name="sglang:pending_prealloc_token_usage",
186
+ documentation="The token usage for pending preallocated tokens (not preallocated yet).",
187
+ labelnames=labels.keys(),
188
+ multiprocess_mode="mostrecent",
189
+ )
179
190
  self.swa_token_usage = Gauge(
180
191
  name="sglang:swa_token_usage",
181
192
  documentation="The token usage for SWA layers.",
@@ -220,6 +231,12 @@ class SchedulerMetricsCollector:
220
231
  labelnames=labels.keys(),
221
232
  multiprocess_mode="mostrecent",
222
233
  )
234
+ self.spec_accept_rate = Gauge(
235
+ name="sglang:spec_accept_rate",
236
+ documentation="The average acceptance rate of speculative decoding (`accepted tokens / total draft tokens` in batch).",
237
+ labelnames=labels.keys(),
238
+ multiprocess_mode="mostrecent",
239
+ )
223
240
 
224
241
  # Retract
225
242
  self.num_retracted_reqs = Gauge(
@@ -485,6 +502,13 @@ class SchedulerMetricsCollector:
485
502
  labelnames=list(labels.keys()) + ["stage"],
486
503
  )
487
504
 
505
+ self.is_cuda_graph = Gauge(
506
+ name="sglang:is_cuda_graph",
507
+ documentation="Whether the batch is using CUDA graph.",
508
+ labelnames=labels.keys(),
509
+ multiprocess_mode="mostrecent",
510
+ )
511
+
488
512
  def _log_gauge(self, gauge, data: Union[int, float]) -> None:
489
513
  # Convenience function for logging to gauge.
490
514
  gauge.labels(**self.labels).set(data)
@@ -509,6 +533,9 @@ class SchedulerMetricsCollector:
509
533
  self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
510
534
  self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
511
535
  self._log_gauge(self.token_usage, stats.token_usage)
536
+ self._log_gauge(
537
+ self.pending_prealloc_token_usage, stats.pending_prealloc_token_usage
538
+ )
512
539
  self._log_gauge(self.swa_token_usage, stats.swa_token_usage)
513
540
  self._log_gauge(self.gen_throughput, stats.gen_throughput)
514
541
  self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
@@ -520,6 +547,7 @@ class SchedulerMetricsCollector:
520
547
 
521
548
  # Speculative decoding
522
549
  self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
550
+ self._log_gauge(self.spec_accept_rate, stats.spec_accept_rate)
523
551
 
524
552
  # PD disaggregation
525
553
  self._log_gauge(
@@ -556,6 +584,9 @@ class SchedulerMetricsCollector:
556
584
  self.engine_load_weights_time, stats.engine_load_weights_time
557
585
  )
558
586
 
587
+ # CUDA graph
588
+ self._log_gauge(self.is_cuda_graph, stats.is_cuda_graph)
589
+
559
590
  self.last_log_time = time.perf_counter()
560
591
 
561
592
  def log_grammar_stats(self, grammar_stats) -> None:
@@ -18,7 +18,7 @@ Records the latency of some functions
18
18
  import asyncio
19
19
  import time
20
20
  from functools import wraps
21
- from typing import Any, Callable, List, Optional
21
+ from typing import Any, Callable, Optional
22
22
 
23
23
  from sglang.srt.metrics.utils import exponential_buckets
24
24
 
@@ -38,8 +38,11 @@ from sglang.srt.layers.dp_attention import (
38
38
  get_attention_tp_rank,
39
39
  get_attention_tp_size,
40
40
  set_dp_buffer_len,
41
+ set_is_extend_in_batch,
41
42
  )
42
43
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
44
+ from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
45
+ from sglang.srt.layers.moe.utils import get_deepep_mode, get_moe_a2a_backend
43
46
  from sglang.srt.layers.torchao_utils import save_gemlite_cache
44
47
  from sglang.srt.model_executor.forward_batch_info import (
45
48
  CaptureHiddenMode,
@@ -53,7 +56,6 @@ from sglang.srt.utils import (
53
56
  empty_context,
54
57
  get_available_gpu_memory,
55
58
  get_bool_env_var,
56
- get_device_memory_capacity,
57
59
  is_hip,
58
60
  log_info_on_rank0,
59
61
  require_attn_tp_gather,
@@ -63,6 +65,13 @@ from sglang.srt.utils import (
63
65
  )
64
66
  from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
65
67
 
68
+ try:
69
+ from kt_kernel import AMXMoEWrapper
70
+
71
+ KTRANSFORMERS_AVAILABLE = True
72
+ except ImportError:
73
+ KTRANSFORMERS_AVAILABLE = False
74
+
66
75
  _is_hip = is_hip()
67
76
 
68
77
  logger = logging.getLogger(__name__)
@@ -241,9 +250,13 @@ class CudaGraphRunner:
241
250
  self.attn_tp_size = get_attention_tp_size()
242
251
  self.attn_tp_rank = get_attention_tp_rank()
243
252
 
253
+ self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
254
+
244
255
  # Batch sizes to capture
245
256
  self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
246
257
  log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
258
+ if KTRANSFORMERS_AVAILABLE:
259
+ AMXMoEWrapper.set_capture_batch_sizes(self.capture_bs)
247
260
  self.capture_forward_mode = ForwardMode.DECODE
248
261
  self.capture_hidden_mode = CaptureHiddenMode.NULL
249
262
  self.num_tokens_per_bs = 1
@@ -274,7 +287,6 @@ class CudaGraphRunner:
274
287
  self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
275
288
  )
276
289
 
277
- # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
278
290
  self.encoder_len_fill_value = 0
279
291
  self.seq_lens_cpu = torch.full(
280
292
  (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
@@ -637,6 +649,7 @@ class CudaGraphRunner:
637
649
  # Clean intermediate result cache for DP attention
638
650
  forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
639
651
  set_dp_buffer_len(global_dp_buffer_len, num_tokens)
652
+ set_is_extend_in_batch(False)
640
653
 
641
654
  kwargs = {}
642
655
  if (
@@ -655,6 +668,8 @@ class CudaGraphRunner:
655
668
  )
656
669
  return logits_output_or_pp_proxy_tensors
657
670
 
671
+ self.deepep_adapter.capture(is_extend_in_batch=False)
672
+
658
673
  for _ in range(2):
659
674
  self.device_module.synchronize()
660
675
  self.model_runner.tp_group.barrier()
@@ -678,8 +693,9 @@ class CudaGraphRunner:
678
693
  capture_hidden_mode_required_by_forward_batch = (
679
694
  forward_batch.capture_hidden_mode
680
695
  )
681
- capture_hidden_mode_required_by_spec_info = getattr(
682
- forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
696
+ capture_hidden_mode_required_by_spec_info = (
697
+ getattr(forward_batch.spec_info, "capture_hidden_mode", None)
698
+ or CaptureHiddenMode.NULL
683
699
  )
684
700
  capture_hidden_mode_required_for_returning_hidden_states = (
685
701
  CaptureHiddenMode.FULL
@@ -797,6 +813,8 @@ class CudaGraphRunner:
797
813
  skip_attn_backend_init: bool = False,
798
814
  pp_proxy_tensors: Optional[PPProxyTensors] = None,
799
815
  ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
816
+ self.deepep_adapter.replay()
817
+
800
818
  if not skip_attn_backend_init:
801
819
  self.replay_prepare(forward_batch, pp_proxy_tensors)
802
820
  else:
@@ -849,7 +867,7 @@ class CudaGraphRunner:
849
867
  )
850
868
 
851
869
  elif self.model_runner.spec_algorithm.is_ngram():
852
- from sglang.srt.speculative.ngram_utils import NgramVerifyInput
870
+ from sglang.srt.speculative.ngram_info import NgramVerifyInput
853
871
 
854
872
  spec_info = NgramVerifyInput(
855
873
  draft_token=None,
@@ -873,3 +891,23 @@ CUDA_GRAPH_CAPTURE_FAILED_MSG = (
873
891
  "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
874
892
  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
875
893
  )
894
+
895
+
896
+ class DeepEPCudaGraphRunnerAdapter:
897
+ def __init__(self):
898
+ # Record DeepEP mode used during capture to ensure replay consistency
899
+ self._captured_deepep_mode = None
900
+
901
+ def capture(self, is_extend_in_batch: bool):
902
+ if not get_moe_a2a_backend().is_deepep():
903
+ return
904
+ self._captured_deepep_mode = get_deepep_mode().resolve(
905
+ is_extend_in_batch=is_extend_in_batch
906
+ )
907
+ DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
908
+
909
+ def replay(self):
910
+ if not get_moe_a2a_backend().is_deepep():
911
+ return
912
+ assert self._captured_deepep_mode is not None
913
+ DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)