sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to their respective public registries. It is provided for informational purposes only.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args_config_parser.py

@@ -5,7 +5,7 @@ Handles merging of YAML configuration files with command-line arguments.
 
 import logging
 from pathlib import Path
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List
 
 import yaml
 
sglang/srt/single_batch_overlap.py

@@ -1,17 +1,32 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import torch
 
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe import get_moe_runner_backend
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import is_sbo_enabled
-from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import get_int_env_var
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
 
 class SboFlags:
@@ -43,7 +58,7 @@ class CombineOverlapArgs:
     wait_event: torch.cuda.Event
     num_sms: int
     signal: Optional[torch.Tensor] = None
-    threshold: int = -1
+    threshold: int = 0
 
 
 @dataclass
@@ -55,57 +70,47 @@ class DownGemmOverlapArgs:
 
 def execute_sbo(
     forward_shared_experts: Callable[[], Any],
-    experts: "DeepEPMoE",
+    experts: FusedMoE,
     hidden_states: torch.Tensor,
-    topk_idx: torch.Tensor,
-    topk_weights: torch.Tensor,
-    forward_batch: ForwardBatch,
-    alt_stream: Optional = None,
+    topk_output: TopKOutput,
+    alt_stream: Optional[torch.cuda.Stream] = None,
+    disable_sbo: bool = False,
 ):
-    shared_output = None
 
-    dispatch_output = experts.dispatch(
-        hidden_states, topk_idx, topk_weights, forward_batch
+    dispatch_output = experts.dispatcher.dispatch(
+        hidden_states=hidden_states, topk_output=topk_output
     )
 
     combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
-        _compute_overlap_args(dispatch_output, alt_stream)
+        _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
    )
 
-    hidden_states = experts.moe_impl(
+    combine_input = experts.run_moe_core(
         dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
     )
     if (e := meta_overlap_args.get("record_event_after_down")) is not None:
         e.record()
 
-    if SboFlags.enable_combine_shared_two_stream_overlap():
+    if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap():
         # TODO reduce sm for non-deepgemm
         with deep_gemm_wrapper.configure_deep_gemm_num_sms(
             meta_overlap_args["compute_num_sms"]
         ):
-            shared_output = forward_shared_experts()
-
-    hidden_states = experts.combine(
-        hidden_states,
-        dispatch_output.topk_idx,
-        dispatch_output.topk_weights,
-        forward_batch,
-        overlap_args=combine_overlap_args,
-    )
+            forward_shared_experts()
+
+    hidden_states = experts.dispatcher.combine(combine_input=combine_input)
 
-    return hidden_states, shared_output
+    return hidden_states
 
 
-def _compute_overlap_args(dispatch_output, alt_stream):
-    if not (
+def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
+    if disable_sbo or not (
         SboFlags.enable_combine_down_gemm_two_stream_overlap()
         or SboFlags.enable_combine_shared_two_stream_overlap()
     ):
         return None, None, {}
 
-    hidden_states = dispatch_output.hidden_states_fp8
-    if isinstance(hidden_states, tuple):
-        hidden_states = hidden_states[0]
+    hidden_states = dispatch_output.hidden_states
 
     num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
 
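Taken together, these hunks narrow execute_sbo's contract: top-k routing results travel as a single TopKOutput, dispatch and combine go through experts.dispatcher, the shared-expert output is no longer returned to the caller, and a new disable_sbo flag bypasses the overlap paths entirely. A minimal caller sketch of the new shape; layer, topk_output, and alt_stream are illustrative placeholders, not names from this diff:

# Hypothetical caller; only execute_sbo's signature comes from the diff above.
hidden_states = execute_sbo(
    forward_shared_experts=lambda: layer.shared_experts(hidden_states),
    experts=layer.experts,    # a FusedMoE whose .dispatcher handles dispatch/combine
    hidden_states=hidden_states,
    topk_output=topk_output,  # a TopKOutput produced by the router's top-k step
    alt_stream=alt_stream,    # optional second CUDA stream used for overlap
    disable_sbo=False,        # new escape hatch: run without single-batch overlap
)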
sglang/srt/speculative/base_spec_worker.py

@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.tp_worker import TpModelWorker
+
+
+class BaseDraftWorker(ABC):
+    @abstractmethod
+    def draft():
+        pass
+
+    @abstractmethod
+    def draft_extend():
+        pass
+
+
+class BaseSpecWorker(ABC):
+    @property
+    @abstractmethod
+    def target_worker(self) -> TpModelWorker:
+        pass
+
+    @property
+    @abstractmethod
+    def draft_worker(self) -> BaseDraftWorker:
+        pass
+
+    @abstractmethod
+    def clear_cache_pool(self):
+        # TODO: move this abstract method to BaseTpWorker and call through self.model_runner
+        pass
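The new base_spec_worker.py pins down the interface that speculative workers are expected to implement. A hedged sketch of a conforming subclass; MyDraftWorker and MySpecWorker are illustrative names, and only the abstract interface comes from the file above:

from sglang.srt.speculative.base_spec_worker import BaseDraftWorker, BaseSpecWorker


class MyDraftWorker(BaseDraftWorker):
    def draft(self):
        ...  # propose draft tokens with the draft model

    def draft_extend(self):
        ...  # run the draft model over tokens accepted by the target


class MySpecWorker(BaseSpecWorker):
    def __init__(self, target_worker, draft_worker):
        self._target_worker = target_worker
        self._draft_worker = draft_worker

    @property
    def target_worker(self):
        return self._target_worker

    @property
    def draft_worker(self):
        return self._draft_worker

    def clear_cache_pool(self):
        ...  # reset KV-cache state between generations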
sglang/srt/speculative/draft_utils.py

@@ -0,0 +1,226 @@
+import logging
+
+from sglang.srt.server_args import ServerArgs, get_global_server_args
+from sglang.srt.utils.common import is_blackwell
+
+logger = logging.getLogger(__name__)
+
+
+class DraftBackendFactory:
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        draft_model_runner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        self.server_args = server_args
+        self.draft_model_runner = draft_model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def create_decode_backend(self):
+        if self.speculative_num_steps == 1:
+            return None
+
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+            "nsa": self._create_nsa_decode_backend,
+        }
+
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )
+
+    def create_draft_extend_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
+            "flashmla": self._create_flashmla_prefill_backend,
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+            "nsa": self._create_nsa_prefill_backend,
+        }
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
+        return self._create_backend(
+            backend_name,
+            backend_map,
+            "EAGLE is not supported in attention backend {backend_type}",
+        )
+
+    def _create_nsa_decode_backend(self):
+        from sglang.srt.layers.attention.nsa_backend import (
+            NativeSparseAttnMultiStepBackend,
+        )
+
+        return NativeSparseAttnMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_nsa_prefill_backend(self):
+        from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+        return NativeSparseAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_flashinfer_decode_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
+            )
+
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
+            )
+
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+
+    def _create_triton_decode_backend(self):
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mha_prefill_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_flashmla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
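The new draft_utils.py centralizes draft-backend construction behind a factory keyed on the server's attention-backend settings. A hedged usage sketch; how server_args, draft_model_runner, topk, and speculative_num_steps are obtained is outside this diff:

from sglang.srt.speculative.draft_utils import DraftBackendFactory

factory = DraftBackendFactory(
    server_args=server_args,                # the running server's ServerArgs
    draft_model_runner=draft_model_runner,  # ModelRunner of the draft model
    topk=topk,
    speculative_num_steps=speculative_num_steps,
)

# Multi-step decode backend; returns None when speculative_num_steps == 1.
decode_backend = factory.create_decode_backend()

# Draft-extend backend; reads server_args.decode_attention_backend when
# speculative_attention_mode == "decode", else prefill_attention_backend.
draft_extend_backend = factory.create_draft_extend_backend()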
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -9,10 +9,12 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
+    DeepEPCudaGraphRunnerAdapter,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
     model_capture_mode,
     set_global_graph_memory_pool,
+    set_is_extend_in_batch,
     set_torch_compile_config,
 )
 from sglang.srt.model_executor.forward_batch_info import (
@@ -40,8 +42,11 @@ class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
         # Parse args
         self.eagle_worker = eagle_worker
-        self.model_runner = model_runner = eagle_worker.model_runner
-        self.model_runner: EAGLEWorker
+        if not hasattr(eagle_worker, "model_runner"):
+            # V2: EagleDraftWorker
+            self.model_runner = model_runner = eagle_worker.draft_runner
+        else:
+            self.model_runner = model_runner = eagle_worker.model_runner
         self.graphs = {}
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -58,6 +63,7 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_profile_cuda_graph = (
             model_runner.server_args.enable_profile_cuda_graph
         )
+        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
         server_args = model_runner.server_args
 
         # Batch sizes to capture
@@ -76,6 +82,7 @@ class EAGLEDraftCudaGraphRunner:
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs
 
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -87,6 +94,7 @@ class EAGLEDraftCudaGraphRunner:
         self.seq_lens = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.zeros(
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
@@ -160,6 +168,9 @@
         # Graph inputs
         req_pool_indices = self.req_pool_indices[:num_seqs]
         seq_lens = self.seq_lens[:num_seqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
+        extend_seq_lens = self.extend_seq_lens[:num_seqs]
+        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
         mrope_positions = self.mrope_positions[:, :num_tokens]
@@ -222,6 +233,9 @@
             input_ids=None,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
+            extend_seq_lens=extend_seq_lens,
+            extend_seq_lens_cpu=extend_seq_lens_cpu,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
@@ -250,6 +264,7 @@
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
             set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+            set_is_extend_in_batch(False)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
@@ -261,6 +276,8 @@
             forward_batch.spec_info.hidden_states = hidden_states_backup
             return ret
 
+        self.deepep_adapter.capture(is_extend_in_batch=False)
+
         for _ in range(2):
             torch.cuda.synchronize()
             self.model_runner.tp_group.barrier()
@@ -276,14 +293,14 @@
         return graph, out
 
     def _postprocess_output_to_raw_bs(self, out, raw_bs):
-        score_list, token_list, parents_list = out
-        score_list = [x[:raw_bs] for x in score_list]
-        token_list = [x[:raw_bs] for x in token_list]
-        parents_list = [x[:raw_bs] for x in parents_list]
-        return (score_list, token_list, parents_list)
+        # Keep the variables name for readability
+        parent_list, top_scores_index, draft_tokens = (t[:raw_bs] for t in out)
+        return parent_list, top_scores_index, draft_tokens
 
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
+        self.deepep_adapter.replay()
+
         raw_bs = forward_batch.batch_size
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
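Across these hunks the draft runner gains a DeepEPCudaGraphRunnerAdapter that is told at graph-capture time whether the batch is an extend batch, and restores that state on every replay. A skeletal restatement of the call protocol visible above; the adapter's internals are not shown in this diff:

from sglang.srt.model_executor.cuda_graph_runner import DeepEPCudaGraphRunnerAdapter


class GraphRunnerSkeleton:
    """Skeleton only; the real capture/replay logic lives in the runners above."""

    def __init__(self):
        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()

    def capture(self):
        # Record DeepEP's dispatch mode before capturing graphs; the draft runner
        # passes is_extend_in_batch=False, the draft-extend runner passes True.
        self.deepep_adapter.capture(is_extend_in_batch=False)

    def replay(self, forward_batch):
        # Restore the captured DeepEP state before launching the graph.
        self.deepep_adapter.replay()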
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -9,11 +9,13 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
+    DeepEPCudaGraphRunnerAdapter,
     LogitsProcessorOutput,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
     model_capture_mode,
     set_global_graph_memory_pool,
+    set_is_extend_in_batch,
     set_torch_compile_config,
 )
 from sglang.srt.model_executor.forward_batch_info import (
@@ -38,7 +40,12 @@ class EAGLEDraftExtendCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
         # Parse args
         self.eagle_worker = eagle_worker
-        self.model_runner = model_runner = eagle_worker.model_runner
+        if not hasattr(eagle_worker, "model_runner"):
+            # V2: EagleDraftWorker
+            self.model_runner = model_runner = eagle_worker.draft_runner
+        else:
+            self.model_runner = model_runner = eagle_worker.model_runner
+
         self.graphs = {}
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -56,6 +63,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         )
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.padded_static_len = -1
+        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
 
         # Attention backend
         self.num_tokens_per_bs = self.speculative_num_steps + 1
@@ -71,6 +79,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs
 
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -189,7 +198,9 @@
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
+        seq_lens_cpu = self.seq_lens_cpu[:bs]
         extend_seq_lens = self.extend_seq_lens[:bs]
+        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
@@ -238,6 +249,8 @@
         )
         spec_info.positions = None
 
+        self.deepep_adapter.capture(is_extend_in_batch=True)
+
         # Forward batch
         forward_batch = ForwardBatch(
             forward_mode=ForwardMode.DRAFT_EXTEND,
@@ -245,6 +258,7 @@
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
@@ -262,6 +276,7 @@
             capture_hidden_mode=CaptureHiddenMode.LAST,
             attn_backend=self.eagle_worker.draft_extend_attn_backend,
             extend_seq_lens=extend_seq_lens,
+            extend_seq_lens_cpu=extend_seq_lens_cpu,
             padded_static_len=self.padded_static_len,
         )
 
@@ -280,12 +295,13 @@
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
             set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+            set_is_extend_in_batch(False)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
 
-            ret = self.eagle_worker.draft_model_runner.model.forward(
+            ret = self.model_runner.model.forward(
                 forward_batch.input_ids,
                 forward_batch.positions,
                 forward_batch,
@@ -313,6 +329,8 @@
 
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
+        self.deepep_adapter.replay()
+
         # batch_size and num_seqs can be different in case there are finished examples
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
@@ -362,6 +380,9 @@
         self.seq_lens_cpu.fill_(self.seq_len_fill_value)
         self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
+        if forward_batch.extend_seq_lens_cpu is not None:
+            self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu
+
         if bs != raw_bs:
             forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
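The replay-side hunk mirrors the capture-side buffers: captured buffers are sized for the padded batch while live data covers only raw_bs entries, so tensor metadata is reset with fill_ and overwritten with copy_, whereas extend_seq_lens_cpu, allocated as a plain Python list, is updated by slice assignment. A standalone illustration of that pad-to-captured-batch pattern; the sizes and values here are made up:

import torch

max_bs, raw_bs, fill_value = 8, 3, 0
seq_lens_cpu = torch.full((max_bs,), fill_value, dtype=torch.int32)
extend_seq_lens_cpu = [fill_value] * max_bs  # plain list, as in the runner

batch_seq_lens = torch.tensor([5, 9, 2], dtype=torch.int32)
seq_lens_cpu.fill_(fill_value)               # reset the padding region
seq_lens_cpu[:raw_bs].copy_(batch_seq_lens)  # tensor: in-place copy_

batch_extend_lens = [1, 4, 2]
extend_seq_lens_cpu[:raw_bs] = batch_extend_lens  # list: slice assignment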