sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,125 @@
1
+ #pragma once
2
+
3
+ #include <cstddef>
4
+ #include <iostream>
5
+ #include <limits>
6
+ #include <regex>
7
+ #include <sstream>
8
+ #include <stdexcept>
9
+ #include <string>
10
+ #include <vector>
11
+
12
namespace ngram {

// Tunable parameters of the n-gram speculative-decoding cache.
//
// The batch_* vectors are per-batch-size override tables built by parse():
// entry i applies to batch size i, and an entry equal to
// std::numeric_limits<size_t>::max() means "not overridden", in which case the
// getters fall back to the corresponding scalar field.
struct Param {
  bool enable;
  bool enable_router_mode;
  size_t min_bfs_breadth;
  size_t max_bfs_breadth;
  size_t min_match_window_size;
  size_t max_match_window_size;
  size_t branch_length;
  size_t draft_token_num;
  std::string match_type;

  // Override tables indexed by batch size (see parse()).
  std::vector<size_t> batch_min_match_window_size;
  std::vector<size_t> batch_draft_token_num;

  // Returns the draft-token override for `batch_size`, or
  // `draft_token_num - 1` when no override is set.
  // NOTE(review): overrides are returned as-is while the fallback subtracts
  // one, and the fallback wraps around if draft_token_num == 0 -- confirm this
  // asymmetry is what callers expect.
  size_t get_draft_token_num(size_t batch_size) const {
    if (batch_size < batch_draft_token_num.size() &&
        batch_draft_token_num[batch_size] != std::numeric_limits<size_t>::max()) {
      return batch_draft_token_num[batch_size];
    }
    return draft_token_num - 1;
  }

  // Returns the min-match-window override for `batch_size`, or
  // min_match_window_size when no override is set.
  size_t get_min_match_window_size(size_t batch_size) const {
    if (batch_size < batch_min_match_window_size.size() &&
        batch_min_match_window_size[batch_size] != std::numeric_limits<size_t>::max()) {
      return batch_min_match_window_size[batch_size];
    }
    return min_match_window_size;
  }

  // Converts `text` to a size_t. Leading/trailing spaces and tabs are
  // ignored; anything else that is not a decimal digit (including a sign),
  // an empty string, or a value that overflows size_t throws
  // std::runtime_error mentioning the full `config` string.
  static size_t parse_unsigned(const std::string& text, const std::string& config) {
    size_t begin = 0;
    size_t end = text.size();
    while (begin < end && (text[begin] == ' ' || text[begin] == '\t')) {
      ++begin;
    }
    while (end > begin && (text[end - 1] == ' ' || text[end - 1] == '\t')) {
      --end;
    }
    if (begin == end) {
      throw std::runtime_error("invalid number '" + text + "', config: " + config);
    }
    const size_t max_value = std::numeric_limits<size_t>::max();
    size_t number = 0;
    for (size_t i = begin; i < end; ++i) {
      const char c = text[i];
      if (c < '0' || c > '9') {
        throw std::runtime_error("invalid number '" + text + "', config: " + config);
      }
      const size_t digit = static_cast<size_t>(c - '0');
      if (number > (max_value - digit) / 10) {
        throw std::runtime_error("number out of range '" + text + "', config: " + config);
      }
      number = number * 10 + digit;
    }
    return number;
  }

  // Parses a per-batch-size override table, e.g. "0-1|10,2-3|20,".
  //
  // Each comma-separated segment has the form "L-R|V" and assigns V to every
  // batch size in the inclusive range [L, R]. A trailing comma is tolerated.
  // The result has size max(R) + 1; positions not covered by any segment hold
  // std::numeric_limits<size_t>::max(). An empty string yields an empty
  // vector.
  //
  // Throws std::runtime_error on a malformed segment, a non-numeric or
  // negative number, L > R, R > 128, overlapping ranges, or a value equal to
  // the "unset" sentinel.
  std::vector<size_t> parse(const std::string& value) {
    std::vector<size_t> result;
    if (value.empty()) {
      return result;
    }
    const size_t unset = std::numeric_limits<size_t>::max();
    std::vector<bool> assigned;  // assigned[i] is true once batch size i is set
    const std::regex comma_re(",");
    const std::regex pipe_re("\\|");
    const std::regex dash_re("-");

    std::sregex_token_iterator first{value.begin(), value.end(), comma_re, -1}, last;
    for (const auto& seg : std::vector<std::string>(first, last)) {
      std::sregex_token_iterator seg_first{seg.begin(), seg.end(), pipe_re, -1}, seg_last;
      std::vector<std::string> part(seg_first, seg_last);
      if (part.size() != 2) {
        throw std::runtime_error(
            "failed to get config, invalid config: " + seg + ", part's size = " + std::to_string(part.size()));
      }

      std::sregex_token_iterator range_first{part[0].begin(), part[0].end(), dash_re, -1}, range_last;
      std::vector<std::string> range(range_first, range_last);
      if (range.size() != 2) {
        throw std::runtime_error("failed to get range, invalid config: " + value);
      }
      const size_t L = parse_unsigned(range[0], value);
      const size_t R = parse_unsigned(range[1], value);
      if (L > R || R > 128) {
        throw std::runtime_error("invalid range, config: " + value);
      }

      const size_t config = parse_unsigned(part[1], value);
      if (config == unset) {
        // This value is reserved as the "not overridden" marker.
        throw std::runtime_error("invalid value, config: " + value);
      }

      if (R >= result.size()) {
        result.resize(R + 1, unset);
        assigned.resize(result.size(), false);
      }
      for (size_t pos = L; pos <= R; ++pos) {
        if (assigned[pos]) {
          throw std::runtime_error("repeated position " + std::to_string(pos) + ", config : " + value);
        }
        assigned[pos] = true;
        result[pos] = config;
      }
    }
    return result;
  }

  // Replaces the min-match-window override table with parse(value).
  void resetBatchMinMatchWindowSize(const std::string& value) {
    batch_min_match_window_size = parse(value);
  }

  // Replaces the draft-token override table with parse(value).
  void resetBatchReturnTokenNum(const std::string& value) {
    batch_draft_token_num = parse(value);
  }

  // Returns a human-readable dump of every field, including the override
  // tables as "index|value," pairs.
  std::string detail() {
    std::stringstream ss;
    ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode
       << ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth
       << ", min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size
       << ", branch_length = " << branch_length << ", draft_token_num = " << draft_token_num
       << ", match_type = " << match_type;
    ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = ";
    for (size_t i = 0; i < batch_min_match_window_size.size(); ++i) {
      ss << i << "|" << batch_min_match_window_size[i] << ",";
    }
    ss << ", batch_draft_token_num(" << batch_draft_token_num.size() << ") = ";
    for (size_t i = 0; i < batch_draft_token_num.size(); ++i) {
      ss << i << "|" << batch_draft_token_num[i] << ",";
    }
    return ss.str();
  }
};

}  // namespace ngram
@@ -0,0 +1,71 @@
1
#pragma once

#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <queue>
#include <utility>

namespace utils {

// Unbounded FIFO queue guarded by a mutex, with blocking dequeue and a
// one-way close().
//
// Once close() has been called:
//   * enqueue() rejects new elements and returns false;
//   * dequeue() returns false immediately -- elements still buffered at that
//     moment are NOT handed out (close() acts as "stop now", not "drain").
template <typename T>
class Queue {
 public:
  // Move `rhs` into the queue. Returns false (and leaves the queue unchanged)
  // if the queue has been closed.
  bool enqueue(T&& rhs) {
    return push(std::move(rhs));
  }

  // Copy `rhs` into the queue. Returns false (and leaves the queue unchanged)
  // if the queue has been closed.
  bool enqueue(const T& rhs) {
    return push(rhs);
  }

  // Block until an element is available or the queue is closed.
  // On success, move-assigns the front element into `rhs`, removes it, and
  // returns true. Returns false once the queue is closed (see class comment).
  bool dequeue(T& rhs) {
    std::unique_lock<std::mutex> lock(mutex_);
    // The predicate form of wait() re-checks the condition on every wakeup,
    // which handles spurious wakeups.
    cv_.wait(lock, [this] { return !queue_.empty() || closed_; });
    if (closed_) {
      return false;
    }
    rhs = std::move(queue_.front());
    queue_.pop();
    return true;
  }

  // Number of buffered elements (a snapshot; may change immediately).
  size_t size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

  // Whether the queue currently holds no elements (a snapshot).
  bool empty() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.empty();
  }

  // Mark the queue closed and wake every thread blocked in dequeue().
  void close() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      closed_ = true;
    }
    cv_.notify_all();
  }

 private:
  // Shared implementation of both enqueue() overloads; perfect-forwards so the
  // rvalue overload moves and the lvalue overload copies.
  template <typename U>
  bool push(U&& value) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (closed_) {
        return false;
      }
      queue_.emplace(std::forward<U>(value));
    }
    // Notify after releasing the lock so the woken consumer does not
    // immediately block on mutex_.
    cv_.notify_one();
    return true;
  }

  std::queue<T> queue_;
  mutable std::mutex mutex_;  // mutable so size()/empty() can lock in const methods
  std::condition_variable cv_;
  bool closed_{false};
};

}  // namespace utils
@@ -0,0 +1,226 @@
1
+ import logging
2
+
3
+ from sglang.srt.server_args import ServerArgs, get_global_server_args
4
+ from sglang.srt.utils.common import is_blackwell
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
class DraftBackendFactory:
    """Build attention backends for the EAGLE draft model.

    The backend name is read from ``server_args`` (``decode_attention_backend``
    or ``prefill_attention_backend``), falling back to
    ``server_args.attention_backend`` when the specific option is ``None``.
    Each ``_create_*`` method imports its backend module lazily, so only the
    selected backend's module is imported.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        draft_model_runner,
        topk: int,
        speculative_num_steps: int,
    ):
        self.server_args = server_args
        self.draft_model_runner = draft_model_runner
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------

    def _create_backend(
        self, backend_name: str, backend_map: dict, error_template: str
    ):
        """Resolve ``server_args.<backend_name>`` and invoke its factory.

        Args:
            backend_name: name of the ``server_args`` attribute holding the
                backend choice; ``None`` falls back to ``attention_backend``.
            backend_map: backend name -> zero-argument factory callable.
            error_template: message template with a ``{backend_type}`` field.

        Raises:
            ValueError: if the resolved backend name is not in ``backend_map``.
        """
        requested = getattr(self.server_args, backend_name)
        backend_type = (
            self.server_args.attention_backend if requested is None else requested
        )
        factory = backend_map.get(backend_type)
        if factory is None:
            raise ValueError(error_template.format(backend_type=backend_type))
        return factory()

    def create_decode_backend(self):
        """Return the multi-step draft decode backend.

        Returns ``None`` when ``speculative_num_steps == 1``.
        """
        if self.speculative_num_steps == 1:
            return None

        # hybrid_linear_attn uses Triton on Blackwell and FA3 elsewhere.
        hybrid_factory = (
            self._create_triton_decode_backend
            if is_blackwell()
            else self._create_fa3_decode_backend
        )
        factories = {
            "flashinfer": self._create_flashinfer_decode_backend,
            "triton": self._create_triton_decode_backend,
            "aiter": self._create_aiter_decode_backend,
            "fa3": self._create_fa3_decode_backend,
            "hybrid_linear_attn": hybrid_factory,
            "flashmla": self._create_flashmla_decode_backend,
            "trtllm_mha": self._create_trtllm_mha_decode_backend,
            "trtllm_mla": self._create_trtllm_mla_decode_backend,
            "nsa": self._create_nsa_decode_backend,
        }
        return self._create_backend(
            "decode_attention_backend",
            factories,
            "EAGLE is not supported in decode attention backend {backend_type}",
        )

    def create_draft_extend_backend(self):
        """Return the single-step backend used for the draft-extend pass.

        The backend option consulted depends on
        ``server_args.speculative_attention_mode``: ``"decode"`` selects
        ``decode_attention_backend``; anything else selects
        ``prefill_attention_backend``.
        """
        hybrid_factory = (
            self._create_triton_prefill_backend
            if is_blackwell()
            else self._create_fa3_prefill_backend
        )
        factories = {
            "flashinfer": self._create_flashinfer_prefill_backend,
            "triton": self._create_triton_prefill_backend,
            "aiter": self._create_aiter_prefill_backend,
            "fa3": self._create_fa3_prefill_backend,
            "hybrid_linear_attn": hybrid_factory,
            "flashmla": self._create_flashmla_prefill_backend,
            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
            "nsa": self._create_nsa_prefill_backend,
        }
        if self.server_args.speculative_attention_mode == "decode":
            option = "decode_attention_backend"
        else:
            option = "prefill_attention_backend"
        return self._create_backend(
            option,
            factories,
            "EAGLE is not supported in attention backend {backend_type}",
        )

    # ------------------------------------------------------------------
    # Shared construction helpers
    # ------------------------------------------------------------------

    def _build_multi_step_backend(self, backend_cls):
        """Instantiate a multi-step draft decode backend class."""
        return backend_cls(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _build_extend_backend(self, backend_cls):
        """Instantiate a single-step backend with prefill enabled."""
        return backend_cls(self.draft_model_runner, skip_prefill=False)

    @staticmethod
    def _require_mla_backend():
        """Raise ``ValueError`` unless the global server args enable MLA."""
        if not get_global_server_args().use_mla_backend:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

    # ------------------------------------------------------------------
    # Decode (multi-step) backends
    # ------------------------------------------------------------------

    def _create_nsa_decode_backend(self):
        from sglang.srt.layers.attention.nsa_backend import (
            NativeSparseAttnMultiStepBackend,
        )

        return self._build_multi_step_backend(NativeSparseAttnMultiStepBackend)

    def _create_flashinfer_decode_backend(self):
        # MLA models use the dedicated FlashInfer MLA draft backend.
        if get_global_server_args().use_mla_backend:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAMultiStepDraftBackend as backend_cls,
            )
        else:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferMultiStepDraftBackend as backend_cls,
            )

        return self._build_multi_step_backend(backend_cls)

    def _create_triton_decode_backend(self):
        from sglang.srt.layers.attention.triton_backend import (
            TritonMultiStepDraftBackend,
        )

        return self._build_multi_step_backend(TritonMultiStepDraftBackend)

    def _create_aiter_decode_backend(self):
        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend

        return self._build_multi_step_backend(AiterMultiStepDraftBackend)

    def _create_fa3_decode_backend(self):
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionMultiStepBackend,
        )

        return self._build_multi_step_backend(FlashAttentionMultiStepBackend)

    def _create_flashmla_decode_backend(self):
        from sglang.srt.layers.attention.flashmla_backend import (
            FlashMLAMultiStepDraftBackend,
        )

        return self._build_multi_step_backend(FlashMLAMultiStepDraftBackend)

    def _create_trtllm_mha_decode_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import (
            TRTLLMHAAttnMultiStepDraftBackend,
        )

        return self._build_multi_step_backend(TRTLLMHAAttnMultiStepDraftBackend)

    def _create_trtllm_mla_decode_backend(self):
        self._require_mla_backend()

        from sglang.srt.layers.attention.trtllm_mla_backend import (
            TRTLLMMLAMultiStepDraftBackend,
        )

        return self._build_multi_step_backend(TRTLLMMLAMultiStepDraftBackend)

    # ------------------------------------------------------------------
    # Draft-extend (single-step) backends
    # ------------------------------------------------------------------

    def _create_nsa_prefill_backend(self):
        from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend

        return self._build_extend_backend(NativeSparseAttnBackend)

    def _create_flashinfer_prefill_backend(self):
        # MLA models use the dedicated FlashInfer MLA backend.
        if get_global_server_args().use_mla_backend:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAAttnBackend as backend_cls,
            )
        else:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferAttnBackend as backend_cls,
            )

        return self._build_extend_backend(backend_cls)

    def _create_triton_prefill_backend(self):
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        return self._build_extend_backend(TritonAttnBackend)

    def _create_aiter_prefill_backend(self):
        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

        return self._build_extend_backend(AiterAttnBackend)

    def _create_fa3_prefill_backend(self):
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionBackend,
        )

        return self._build_extend_backend(FlashAttentionBackend)

    def _create_trtllm_mha_prefill_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend

        return self._build_extend_backend(TRTLLMHAAttnBackend)

    def _create_trtllm_mla_prefill_backend(self):
        self._require_mla_backend()

        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

        return self._build_extend_backend(TRTLLMMLABackend)

    def _create_flashmla_prefill_backend(self):
        # No flashmla draft-extend backend exists yet: warn and return None.
        logger.warning(
            "flashmla prefill backend is not yet supported for draft extend."
        )
        return None
@@ -9,10 +9,12 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
9
9
  from sglang.srt.model_executor.cuda_graph_runner import (
10
10
  CUDA_GRAPH_CAPTURE_FAILED_MSG,
11
11
  CudaGraphRunner,
12
+ DeepEPCudaGraphRunnerAdapter,
12
13
  get_batch_sizes_to_capture,
13
14
  get_global_graph_memory_pool,
14
15
  model_capture_mode,
15
16
  set_global_graph_memory_pool,
17
+ set_is_extend_in_batch,
16
18
  set_torch_compile_config,
17
19
  )
18
20
  from sglang.srt.model_executor.forward_batch_info import (
@@ -20,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import (
20
22
  ForwardBatch,
21
23
  ForwardMode,
22
24
  )
23
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
25
+ from sglang.srt.speculative.eagle_info import EagleDraftInput
24
26
  from sglang.srt.utils import (
25
27
  require_attn_tp_gather,
26
28
  require_gathered_buffer,
@@ -40,8 +42,11 @@ class EAGLEDraftCudaGraphRunner:
40
42
  def __init__(self, eagle_worker: EAGLEWorker):
41
43
  # Parse args
42
44
  self.eagle_worker = eagle_worker
43
- self.model_runner = model_runner = eagle_worker.model_runner
44
- self.model_runner: EAGLEWorker
45
+ if not hasattr(eagle_worker, "model_runner"):
46
+ # V2: EagleDraftWorker
47
+ self.model_runner = model_runner = eagle_worker.draft_runner
48
+ else:
49
+ self.model_runner = model_runner = eagle_worker.model_runner
45
50
  self.graphs = {}
46
51
  self.output_buffers = {}
47
52
  self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -58,6 +63,7 @@ class EAGLEDraftCudaGraphRunner:
58
63
  self.enable_profile_cuda_graph = (
59
64
  model_runner.server_args.enable_profile_cuda_graph
60
65
  )
66
+ self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
61
67
  server_args = model_runner.server_args
62
68
 
63
69
  # Batch sizes to capture
@@ -76,6 +82,7 @@ class EAGLEDraftCudaGraphRunner:
76
82
  self.seq_lens_cpu = torch.full(
77
83
  (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
78
84
  )
85
+ self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs
79
86
 
80
87
  if self.enable_torch_compile:
81
88
  set_torch_compile_config()
@@ -87,6 +94,7 @@ class EAGLEDraftCudaGraphRunner:
87
94
  self.seq_lens = torch.full(
88
95
  (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
89
96
  )
97
+ self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
90
98
  self.out_cache_loc = torch.zeros(
91
99
  (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
92
100
  )
@@ -160,6 +168,9 @@ class EAGLEDraftCudaGraphRunner:
160
168
  # Graph inputs
161
169
  req_pool_indices = self.req_pool_indices[:num_seqs]
162
170
  seq_lens = self.seq_lens[:num_seqs]
171
+ seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
172
+ extend_seq_lens = self.extend_seq_lens[:num_seqs]
173
+ extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
163
174
  out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
164
175
  positions = self.positions[:num_tokens]
165
176
  mrope_positions = self.mrope_positions[:, :num_tokens]
@@ -222,6 +233,9 @@ class EAGLEDraftCudaGraphRunner:
222
233
  input_ids=None,
223
234
  req_pool_indices=req_pool_indices,
224
235
  seq_lens=seq_lens,
236
+ seq_lens_cpu=seq_lens_cpu,
237
+ extend_seq_lens=extend_seq_lens,
238
+ extend_seq_lens_cpu=extend_seq_lens_cpu,
225
239
  req_to_token_pool=self.model_runner.req_to_token_pool,
226
240
  token_to_kv_pool=self.model_runner.token_to_kv_pool,
227
241
  out_cache_loc=out_cache_loc,
@@ -250,6 +264,7 @@ class EAGLEDraftCudaGraphRunner:
250
264
  # Clean intermediate result cache for DP attention
251
265
  forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
252
266
  set_dp_buffer_len(global_dp_buffer_len, num_tokens)
267
+ set_is_extend_in_batch(False)
253
268
 
254
269
  # Backup two fields, which will be modified in-place in `draft_forward`.
255
270
  output_cache_loc_backup = forward_batch.out_cache_loc
@@ -261,6 +276,8 @@ class EAGLEDraftCudaGraphRunner:
261
276
  forward_batch.spec_info.hidden_states = hidden_states_backup
262
277
  return ret
263
278
 
279
+ self.deepep_adapter.capture(is_extend_in_batch=False)
280
+
264
281
  for _ in range(2):
265
282
  torch.cuda.synchronize()
266
283
  self.model_runner.tp_group.barrier()
@@ -276,14 +293,14 @@ class EAGLEDraftCudaGraphRunner:
276
293
  return graph, out
277
294
 
278
295
  def _postprocess_output_to_raw_bs(self, out, raw_bs):
279
- score_list, token_list, parents_list = out
280
- score_list = [x[:raw_bs] for x in score_list]
281
- token_list = [x[:raw_bs] for x in token_list]
282
- parents_list = [x[:raw_bs] for x in parents_list]
283
- return (score_list, token_list, parents_list)
296
+ # Keep the variables name for readability
297
+ parent_list, top_scores_index, draft_tokens = (t[:raw_bs] for t in out)
298
+ return parent_list, top_scores_index, draft_tokens
284
299
 
285
300
  def replay(self, forward_batch: ForwardBatch):
286
301
  assert forward_batch.out_cache_loc is not None
302
+ self.deepep_adapter.replay()
303
+
287
304
  raw_bs = forward_batch.batch_size
288
305
  raw_num_token = raw_bs * self.num_tokens_per_bs
289
306
 
@@ -302,6 +319,7 @@ class EAGLEDraftCudaGraphRunner:
302
319
  if bs != raw_bs:
303
320
  self.seq_lens.fill_(self.seq_len_fill_value)
304
321
  self.out_cache_loc.zero_()
322
+ self.positions.zero_()
305
323
 
306
324
  num_tokens = bs * self.num_tokens_per_bs
307
325
 
@@ -9,11 +9,13 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
9
9
  from sglang.srt.model_executor.cuda_graph_runner import (
10
10
  CUDA_GRAPH_CAPTURE_FAILED_MSG,
11
11
  CudaGraphRunner,
12
+ DeepEPCudaGraphRunnerAdapter,
12
13
  LogitsProcessorOutput,
13
14
  get_batch_sizes_to_capture,
14
15
  get_global_graph_memory_pool,
15
16
  model_capture_mode,
16
17
  set_global_graph_memory_pool,
18
+ set_is_extend_in_batch,
17
19
  set_torch_compile_config,
18
20
  )
19
21
  from sglang.srt.model_executor.forward_batch_info import (
@@ -21,7 +23,8 @@ from sglang.srt.model_executor.forward_batch_info import (
21
23
  ForwardBatch,
22
24
  ForwardMode,
23
25
  )
24
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
26
+ from sglang.srt.speculative.eagle_info import EagleDraftInput
27
+ from sglang.srt.speculative.spec_utils import fast_topk
25
28
  from sglang.srt.utils import (
26
29
  require_attn_tp_gather,
27
30
  require_gathered_buffer,
@@ -37,7 +40,12 @@ class EAGLEDraftExtendCudaGraphRunner:
37
40
  def __init__(self, eagle_worker: EAGLEWorker):
38
41
  # Parse args
39
42
  self.eagle_worker = eagle_worker
40
- self.model_runner = model_runner = eagle_worker.model_runner
43
+ if not hasattr(eagle_worker, "model_runner"):
44
+ # V2: EagleDraftWorker
45
+ self.model_runner = model_runner = eagle_worker.draft_runner
46
+ else:
47
+ self.model_runner = model_runner = eagle_worker.model_runner
48
+
41
49
  self.graphs = {}
42
50
  self.output_buffers = {}
43
51
  self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -55,6 +63,7 @@ class EAGLEDraftExtendCudaGraphRunner:
55
63
  )
56
64
  self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
57
65
  self.padded_static_len = -1
66
+ self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
58
67
 
59
68
  # Attention backend
60
69
  self.num_tokens_per_bs = self.speculative_num_steps + 1
@@ -70,6 +79,7 @@ class EAGLEDraftExtendCudaGraphRunner:
70
79
  self.seq_lens_cpu = torch.full(
71
80
  (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
72
81
  )
82
+ self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs
73
83
 
74
84
  if self.enable_torch_compile:
75
85
  set_torch_compile_config()
@@ -188,7 +198,9 @@ class EAGLEDraftExtendCudaGraphRunner:
188
198
  input_ids = self.input_ids[:num_tokens]
189
199
  req_pool_indices = self.req_pool_indices[:bs]
190
200
  seq_lens = self.seq_lens[:bs]
201
+ seq_lens_cpu = self.seq_lens_cpu[:bs]
191
202
  extend_seq_lens = self.extend_seq_lens[:bs]
203
+ extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
192
204
  accept_length = self.accept_length[:bs]
193
205
  out_cache_loc = self.out_cache_loc[:num_tokens]
194
206
  positions = self.positions[:num_tokens]
@@ -237,6 +249,8 @@ class EAGLEDraftExtendCudaGraphRunner:
237
249
  )
238
250
  spec_info.positions = None
239
251
 
252
+ self.deepep_adapter.capture(is_extend_in_batch=True)
253
+
240
254
  # Forward batch
241
255
  forward_batch = ForwardBatch(
242
256
  forward_mode=ForwardMode.DRAFT_EXTEND,
@@ -244,6 +258,7 @@ class EAGLEDraftExtendCudaGraphRunner:
244
258
  input_ids=input_ids,
245
259
  req_pool_indices=req_pool_indices,
246
260
  seq_lens=seq_lens,
261
+ seq_lens_cpu=seq_lens_cpu,
247
262
  next_token_logits_buffer=next_token_logits_buffer,
248
263
  req_to_token_pool=self.model_runner.req_to_token_pool,
249
264
  token_to_kv_pool=self.model_runner.token_to_kv_pool,
@@ -261,6 +276,7 @@ class EAGLEDraftExtendCudaGraphRunner:
261
276
  capture_hidden_mode=CaptureHiddenMode.LAST,
262
277
  attn_backend=self.eagle_worker.draft_extend_attn_backend,
263
278
  extend_seq_lens=extend_seq_lens,
279
+ extend_seq_lens_cpu=extend_seq_lens_cpu,
264
280
  padded_static_len=self.padded_static_len,
265
281
  )
266
282
 
@@ -279,12 +295,13 @@ class EAGLEDraftExtendCudaGraphRunner:
279
295
  # Clean intermediate result cache for DP attention
280
296
  forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
281
297
  set_dp_buffer_len(global_dp_buffer_len, num_tokens)
298
+ set_is_extend_in_batch(False)
282
299
 
283
300
  # Backup two fields, which will be modified in-place in `draft_forward`.
284
301
  output_cache_loc_backup = forward_batch.out_cache_loc
285
302
  hidden_states_backup = forward_batch.spec_info.hidden_states
286
303
 
287
- ret = self.eagle_worker.draft_model_runner.model.forward(
304
+ ret = self.model_runner.model.forward(
288
305
  forward_batch.input_ids,
289
306
  forward_batch.positions,
290
307
  forward_batch,
@@ -312,6 +329,8 @@ class EAGLEDraftExtendCudaGraphRunner:
312
329
 
313
330
  def replay(self, forward_batch: ForwardBatch):
314
331
  assert forward_batch.out_cache_loc is not None
332
+ self.deepep_adapter.replay()
333
+
315
334
  # batch_size and num_seqs can be different in case there are finished examples
316
335
  # in the batch, which will not be counted as num_seqs
317
336
  raw_bs = forward_batch.batch_size
@@ -331,6 +350,7 @@ class EAGLEDraftExtendCudaGraphRunner:
331
350
  if bs * self.num_tokens_per_bs != num_tokens:
332
351
  self.seq_lens.fill_(self.seq_len_fill_value)
333
352
  self.out_cache_loc.zero_()
353
+ self.positions.zero_()
334
354
  self.accept_length.fill_(1)
335
355
  self.extend_seq_lens.fill_(1)
336
356
 
@@ -360,6 +380,9 @@ class EAGLEDraftExtendCudaGraphRunner:
360
380
  self.seq_lens_cpu.fill_(self.seq_len_fill_value)
361
381
  self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
362
382
 
383
+ if forward_batch.extend_seq_lens_cpu is not None:
384
+ self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu
385
+
363
386
  if bs != raw_bs:
364
387
  forward_batch.spec_info.positions = self.positions[:num_tokens]
365
388
  forward_batch.spec_info.accept_length = self.accept_length[:bs]