sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -38,8 +38,11 @@ from sglang.srt.layers.dp_attention import (
38
38
  get_attention_tp_rank,
39
39
  get_attention_tp_size,
40
40
  set_dp_buffer_len,
41
+ set_is_extend_in_batch,
41
42
  )
42
43
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
44
+ from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
45
+ from sglang.srt.layers.moe.utils import get_deepep_mode, get_moe_a2a_backend
43
46
  from sglang.srt.layers.torchao_utils import save_gemlite_cache
44
47
  from sglang.srt.model_executor.forward_batch_info import (
45
48
  CaptureHiddenMode,
@@ -48,18 +51,28 @@ from sglang.srt.model_executor.forward_batch_info import (
48
51
  PPProxyTensors,
49
52
  enable_num_token_non_padded,
50
53
  )
51
- from sglang.srt.patch_torch import monkey_patch_torch_compile
52
54
  from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
53
55
  from sglang.srt.utils import (
54
56
  empty_context,
55
57
  get_available_gpu_memory,
56
- get_device_memory_capacity,
58
+ get_bool_env_var,
59
+ is_hip,
57
60
  log_info_on_rank0,
58
61
  require_attn_tp_gather,
59
62
  require_gathered_buffer,
60
63
  require_mlp_sync,
61
64
  require_mlp_tp_gather,
62
65
  )
66
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
67
+
68
+ try:
69
+ from kt_kernel import AMXMoEWrapper
70
+
71
+ KTRANSFORMERS_AVAILABLE = True
72
+ except ImportError:
73
+ KTRANSFORMERS_AVAILABLE = False
74
+
75
+ _is_hip = is_hip()
63
76
 
64
77
  logger = logging.getLogger(__name__)
65
78
 
@@ -100,6 +113,7 @@ def freeze_gc(enable_cudagraph_gc: bool):
100
113
  finally:
101
114
  if should_freeze:
102
115
  gc.unfreeze()
116
+ gc.collect()
103
117
 
104
118
 
105
119
  def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
@@ -136,7 +150,7 @@ def patch_model(
136
150
  mode=os.environ.get(
137
151
  "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
138
152
  ),
139
- dynamic=False,
153
+ dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
140
154
  )
141
155
  else:
142
156
  yield model.forward
@@ -166,29 +180,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
166
180
  server_args = model_runner.server_args
167
181
  capture_bs = server_args.cuda_graph_bs
168
182
 
169
- if capture_bs is None:
170
- if server_args.speculative_algorithm is None:
171
- if server_args.disable_cuda_graph_padding:
172
- capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
173
- else:
174
- capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
175
- else:
176
- # Since speculative decoding requires more cuda graph memory, we
177
- # capture less.
178
- capture_bs = (
179
- list(range(1, 9))
180
- + list(range(10, 33, 2))
181
- + list(range(40, 64, 8))
182
- + list(range(80, 161, 16))
183
- )
184
-
185
- gpu_mem = get_device_memory_capacity()
186
- if gpu_mem is not None:
187
- if gpu_mem > 90 * 1024: # H200, H20
188
- capture_bs += list(range(160, 257, 8))
189
- if gpu_mem > 160 * 1000: # B200, MI300
190
- capture_bs += list(range(256, 513, 16))
191
-
192
183
  if max(capture_bs) > model_runner.req_to_token_pool.size:
193
184
  # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
194
185
  # is very small. We add more values here to make sure we capture the maximum bs.
@@ -204,12 +195,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
204
195
 
205
196
  capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
206
197
 
207
- if server_args.cuda_graph_max_bs:
208
- capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
209
- if max(capture_bs) < server_args.cuda_graph_max_bs:
210
- capture_bs += list(
211
- range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
212
- )
213
198
  capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
214
199
  capture_bs = list(sorted(set(capture_bs)))
215
200
  assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
@@ -265,15 +250,20 @@ class CudaGraphRunner:
265
250
  self.attn_tp_size = get_attention_tp_size()
266
251
  self.attn_tp_rank = get_attention_tp_rank()
267
252
 
253
+ self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
254
+
268
255
  # Batch sizes to capture
269
256
  self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
270
257
  log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
258
+ if KTRANSFORMERS_AVAILABLE:
259
+ AMXMoEWrapper.set_capture_batch_sizes(self.capture_bs)
271
260
  self.capture_forward_mode = ForwardMode.DECODE
272
261
  self.capture_hidden_mode = CaptureHiddenMode.NULL
273
262
  self.num_tokens_per_bs = 1
274
263
  if (
275
264
  model_runner.spec_algorithm.is_eagle()
276
265
  or model_runner.spec_algorithm.is_standalone()
266
+ or model_runner.spec_algorithm.is_ngram()
277
267
  ):
278
268
  if self.model_runner.is_draft_worker:
279
269
  raise RuntimeError("This should not happen")
@@ -297,7 +287,6 @@ class CudaGraphRunner:
297
287
  self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
298
288
  )
299
289
 
300
- # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
301
290
  self.encoder_len_fill_value = 0
302
291
  self.seq_lens_cpu = torch.full(
303
292
  (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
@@ -440,11 +429,21 @@ class CudaGraphRunner:
440
429
  forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
441
430
  )
442
431
 
432
+ is_ngram_supported = (
433
+ (
434
+ forward_batch.batch_size * self.num_tokens_per_bs
435
+ == forward_batch.input_ids.numel()
436
+ )
437
+ if self.model_runner.spec_algorithm.is_ngram()
438
+ else True
439
+ )
440
+
443
441
  return (
444
442
  is_bs_supported
445
443
  and is_encoder_lens_supported
446
444
  and is_tbo_supported
447
445
  and capture_hidden_mode_matches
446
+ and is_ngram_supported
448
447
  )
449
448
 
450
449
  def capture(self) -> None:
@@ -454,6 +453,7 @@ class CudaGraphRunner:
454
453
  activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
455
454
  record_shapes=True,
456
455
  )
456
+ torch.cuda.memory._record_memory_history()
457
457
 
458
458
  # Trigger CUDA graph capture for specific shapes.
459
459
  # Capture the large shapes first so that the smaller shapes
@@ -502,6 +502,8 @@ class CudaGraphRunner:
502
502
  save_gemlite_cache()
503
503
 
504
504
  if self.enable_profile_cuda_graph:
505
+ torch.cuda.memory._dump_snapshot(f"cuda_graph_runner_memory_usage.pickle")
506
+ torch.cuda.memory._record_memory_history(enabled=None)
505
507
  log_message = (
506
508
  "Sorted by CUDA Time:\n"
507
509
  + prof.key_averages(group_by_input_shape=True).table(
@@ -511,6 +513,7 @@ class CudaGraphRunner:
511
513
  + prof.key_averages(group_by_input_shape=True).table(
512
514
  sort_by="cpu_time_total", row_limit=10
513
515
  )
516
+ + "\n\nMemory Usage is saved to cuda_graph_runner_memory_usage.pickle\n"
514
517
  )
515
518
  logger.info(log_message)
516
519
 
@@ -531,6 +534,7 @@ class CudaGraphRunner:
531
534
  input_ids = self.input_ids[:num_tokens]
532
535
  req_pool_indices = self.req_pool_indices[:bs]
533
536
  seq_lens = self.seq_lens[:bs]
537
+ seq_lens_cpu = self.seq_lens_cpu[:bs]
534
538
  out_cache_loc = self.out_cache_loc[:num_tokens]
535
539
  positions = self.positions[:num_tokens]
536
540
  if self.is_encoder_decoder:
@@ -601,6 +605,7 @@ class CudaGraphRunner:
601
605
  input_ids=input_ids,
602
606
  req_pool_indices=req_pool_indices,
603
607
  seq_lens=seq_lens,
608
+ seq_lens_cpu=seq_lens_cpu,
604
609
  next_token_logits_buffer=next_token_logits_buffer,
605
610
  orig_seq_lens=seq_lens,
606
611
  req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -644,6 +649,7 @@ class CudaGraphRunner:
644
649
  # Clean intermediate result cache for DP attention
645
650
  forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
646
651
  set_dp_buffer_len(global_dp_buffer_len, num_tokens)
652
+ set_is_extend_in_batch(False)
647
653
 
648
654
  kwargs = {}
649
655
  if (
@@ -662,6 +668,8 @@ class CudaGraphRunner:
662
668
  )
663
669
  return logits_output_or_pp_proxy_tensors
664
670
 
671
+ self.deepep_adapter.capture(is_extend_in_batch=False)
672
+
665
673
  for _ in range(2):
666
674
  self.device_module.synchronize()
667
675
  self.model_runner.tp_group.barrier()
@@ -685,8 +693,9 @@ class CudaGraphRunner:
685
693
  capture_hidden_mode_required_by_forward_batch = (
686
694
  forward_batch.capture_hidden_mode
687
695
  )
688
- capture_hidden_mode_required_by_spec_info = getattr(
689
- forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
696
+ capture_hidden_mode_required_by_spec_info = (
697
+ getattr(forward_batch.spec_info, "capture_hidden_mode", None)
698
+ or CaptureHiddenMode.NULL
690
699
  )
691
700
  capture_hidden_mode_required_for_returning_hidden_states = (
692
701
  CaptureHiddenMode.FULL
@@ -804,6 +813,8 @@ class CudaGraphRunner:
804
813
  skip_attn_backend_init: bool = False,
805
814
  pp_proxy_tensors: Optional[PPProxyTensors] = None,
806
815
  ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
816
+ self.deepep_adapter.replay()
817
+
807
818
  if not skip_attn_backend_init:
808
819
  self.replay_prepare(forward_batch, pp_proxy_tensors)
809
820
  else:
@@ -834,7 +845,7 @@ class CudaGraphRunner:
834
845
  self.model_runner.spec_algorithm.is_eagle()
835
846
  or self.model_runner.spec_algorithm.is_standalone()
836
847
  ):
837
- from sglang.srt.speculative.eagle_utils import EagleVerifyInput
848
+ from sglang.srt.speculative.eagle_info import EagleVerifyInput
838
849
 
839
850
  if self.model_runner.is_draft_worker:
840
851
  raise RuntimeError("This should not happen.")
@@ -855,6 +866,20 @@ class CudaGraphRunner:
855
866
  seq_lens_cpu=None,
856
867
  )
857
868
 
869
+ elif self.model_runner.spec_algorithm.is_ngram():
870
+ from sglang.srt.speculative.ngram_info import NgramVerifyInput
871
+
872
+ spec_info = NgramVerifyInput(
873
+ draft_token=None,
874
+ tree_mask=self.custom_mask,
875
+ positions=None,
876
+ retrive_index=None,
877
+ retrive_next_token=None,
878
+ retrive_next_sibling=None,
879
+ draft_token_num=self.num_tokens_per_bs,
880
+ )
881
+ spec_info.capture_hidden_mode = CaptureHiddenMode.NULL
882
+
858
883
  return spec_info
859
884
 
860
885
 
@@ -866,3 +891,23 @@ CUDA_GRAPH_CAPTURE_FAILED_MSG = (
866
891
  "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
867
892
  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
868
893
  )
894
+
895
+
896
+ class DeepEPCudaGraphRunnerAdapter:
897
+ def __init__(self):
898
+ # Record DeepEP mode used during capture to ensure replay consistency
899
+ self._captured_deepep_mode = None
900
+
901
+ def capture(self, is_extend_in_batch: bool):
902
+ if not get_moe_a2a_backend().is_deepep():
903
+ return
904
+ self._captured_deepep_mode = get_deepep_mode().resolve(
905
+ is_extend_in_batch=is_extend_in_batch
906
+ )
907
+ DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
908
+
909
+ def replay(self):
910
+ if not get_moe_a2a_backend().is_deepep():
911
+ return
912
+ assert self._captured_deepep_mode is not None
913
+ DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
@@ -44,14 +44,9 @@ from sglang.srt.layers.dp_attention import (
44
44
  get_attention_dp_rank,
45
45
  get_attention_tp_size,
46
46
  set_dp_buffer_len,
47
+ set_is_extend_in_batch,
47
48
  )
48
- from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
49
- from sglang.srt.utils import (
50
- flatten_nested_list,
51
- get_compiler_backend,
52
- is_npu,
53
- support_triton,
54
- )
49
+ from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
55
50
 
56
51
  if TYPE_CHECKING:
57
52
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -60,8 +55,7 @@ if TYPE_CHECKING:
60
55
  from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
61
56
  from sglang.srt.model_executor.model_runner import ModelRunner
62
57
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
63
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
64
- from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
58
+ from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
65
59
 
66
60
  _is_npu = is_npu()
67
61
 
@@ -82,9 +76,7 @@ class ForwardMode(IntEnum):
82
76
  # Used in speculative decoding: extend a batch in the draft model.
83
77
  DRAFT_EXTEND = auto()
84
78
 
85
- # A dummy first batch to start the pipeline for overlap scheduler.
86
- # It is now used for triggering the sampling_info_done event for the first prefill batch.
87
- DUMMY_FIRST = auto()
79
+ DRAFT_EXTEND_V2 = auto()
88
80
 
89
81
  # Split Prefill for PD multiplexing
90
82
  SPLIT_PREFILL = auto()
@@ -92,11 +84,16 @@ class ForwardMode(IntEnum):
92
84
  def is_prefill(self):
93
85
  return self.is_extend()
94
86
 
95
- def is_extend(self):
87
+ def is_extend(self, include_draft_extend_v2: bool = False):
96
88
  return (
97
89
  self == ForwardMode.EXTEND
98
90
  or self == ForwardMode.MIXED
99
91
  or self == ForwardMode.DRAFT_EXTEND
92
+ or (
93
+ self == ForwardMode.DRAFT_EXTEND_V2
94
+ if include_draft_extend_v2
95
+ else False
96
+ )
100
97
  or self == ForwardMode.TARGET_VERIFY
101
98
  )
102
99
 
@@ -115,14 +112,23 @@ class ForwardMode(IntEnum):
115
112
  def is_target_verify(self):
116
113
  return self == ForwardMode.TARGET_VERIFY
117
114
 
118
- def is_draft_extend(self):
115
+ def is_draft_extend(self, include_v2: bool = False):
116
+ if include_v2:
117
+ return (
118
+ self == ForwardMode.DRAFT_EXTEND_V2 or self == ForwardMode.DRAFT_EXTEND
119
+ )
119
120
  return self == ForwardMode.DRAFT_EXTEND
120
121
 
122
+ def is_draft_extend_v2(self):
123
+ # For fixed shape logits output in v2 eagle worker
124
+ return self == ForwardMode.DRAFT_EXTEND_V2
125
+
121
126
  def is_extend_or_draft_extend_or_mixed(self):
122
127
  return (
123
128
  self == ForwardMode.EXTEND
124
129
  or self == ForwardMode.DRAFT_EXTEND
125
130
  or self == ForwardMode.MIXED
131
+ or self == ForwardMode.SPLIT_PREFILL
126
132
  )
127
133
 
128
134
  def is_cuda_graph(self):
@@ -135,9 +141,6 @@ class ForwardMode(IntEnum):
135
141
  def is_cpu_graph(self):
136
142
  return self == ForwardMode.DECODE
137
143
 
138
- def is_dummy_first(self):
139
- return self == ForwardMode.DUMMY_FIRST
140
-
141
144
  def is_split_prefill(self):
142
145
  return self == ForwardMode.SPLIT_PREFILL
143
146
 
@@ -292,14 +295,18 @@ class ForwardBatch:
292
295
  can_run_dp_cuda_graph: bool = False
293
296
  global_forward_mode: Optional[ForwardMode] = None
294
297
 
298
+ # Whether this batch is prefill-only (no token generation needed)
299
+ is_prefill_only: bool = False
300
+
295
301
  # Speculative decoding
296
- spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
302
+ spec_info: Optional[SpecInput] = None
297
303
  spec_algorithm: SpeculativeAlgorithm = None
298
304
  capture_hidden_mode: CaptureHiddenMode = None
299
305
 
300
306
  # For padding
301
307
  padded_static_len: int = -1 # -1 if not padded
302
308
  num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor
309
+ num_token_non_padded_cpu: int = None
303
310
 
304
311
  # For Qwen2-VL
305
312
  mrope_positions: torch.Tensor = None
@@ -338,6 +345,7 @@ class ForwardBatch:
338
345
  is_extend_in_batch=batch.is_extend_in_batch,
339
346
  can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
340
347
  global_forward_mode=batch.global_forward_mode,
348
+ is_prefill_only=batch.is_prefill_only,
341
349
  lora_ids=batch.lora_ids,
342
350
  sampling_info=batch.sampling_info,
343
351
  req_to_token_pool=model_runner.req_to_token_pool,
@@ -361,36 +369,18 @@ class ForwardBatch:
361
369
  ret.num_token_non_padded = torch.tensor(
362
370
  len(batch.input_ids), dtype=torch.int32
363
371
  ).to(device, non_blocking=True)
372
+ ret.num_token_non_padded_cpu = len(batch.input_ids)
364
373
 
365
374
  # For MLP sync
366
375
  if batch.global_num_tokens is not None:
367
- from sglang.srt.speculative.eagle_utils import (
368
- EagleDraftInput,
369
- EagleVerifyInput,
370
- )
371
-
372
376
  assert batch.global_num_tokens_for_logprob is not None
377
+
373
378
  # process global_num_tokens and global_num_tokens_for_logprob
374
379
  if batch.spec_info is not None:
375
- if isinstance(batch.spec_info, EagleDraftInput):
376
- global_num_tokens = [
377
- x * batch.spec_info.num_tokens_per_batch
378
- for x in batch.global_num_tokens
379
- ]
380
- global_num_tokens_for_logprob = [
381
- x * batch.spec_info.num_tokens_for_logprob_per_batch
382
- for x in batch.global_num_tokens_for_logprob
383
- ]
384
- else:
385
- assert isinstance(batch.spec_info, EagleVerifyInput)
386
- global_num_tokens = [
387
- x * batch.spec_info.draft_token_num
388
- for x in batch.global_num_tokens
389
- ]
390
- global_num_tokens_for_logprob = [
391
- x * batch.spec_info.draft_token_num
392
- for x in batch.global_num_tokens_for_logprob
393
- ]
380
+ spec_info: SpecInput = batch.spec_info
381
+ global_num_tokens, global_num_tokens_for_logprob = (
382
+ spec_info.get_spec_adjusted_global_num_tokens(batch)
383
+ )
394
384
  else:
395
385
  global_num_tokens = batch.global_num_tokens
396
386
  global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
@@ -420,10 +410,12 @@ class ForwardBatch:
420
410
  ret.positions = ret.spec_info.positions
421
411
 
422
412
  # Init position information
423
- if ret.forward_mode.is_decode():
413
+ if ret.forward_mode.is_decode() or ret.forward_mode.is_target_verify():
424
414
  if ret.positions is None:
425
415
  ret.positions = clamp_position(batch.seq_lens)
426
416
  else:
417
+ assert isinstance(batch.extend_seq_lens, list)
418
+ assert isinstance(batch.extend_prefix_lens, list)
427
419
  ret.extend_seq_lens = torch.tensor(
428
420
  batch.extend_seq_lens, dtype=torch.int32
429
421
  ).to(device, non_blocking=True)
@@ -669,9 +661,6 @@ class ForwardBatch:
669
661
  )
670
662
 
671
663
  def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
672
-
673
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
674
-
675
664
  assert self.global_num_tokens_cpu is not None
676
665
  assert self.global_num_tokens_for_logprob_cpu is not None
677
666
 
@@ -709,6 +698,7 @@ class ForwardBatch:
709
698
 
710
699
  self.global_dp_buffer_len = buffer_len
711
700
  set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
701
+ set_is_extend_in_batch(self.is_extend_in_batch)
712
702
 
713
703
  bs = self.batch_size
714
704
 
@@ -757,9 +747,8 @@ class ForwardBatch:
757
747
  self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
758
748
  self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
759
749
  self.global_num_tokens_cpu = global_num_tokens
760
- self.global_num_tokens_gpu = self.global_num_tokens_gpu.new_tensor(
761
- global_num_tokens
762
- )
750
+ global_num_tokens_pinned = torch.tensor(global_num_tokens, pin_memory=True)
751
+ self.global_num_tokens_gpu.copy_(global_num_tokens_pinned, non_blocking=True)
763
752
 
764
753
  if self.mrope_positions is not None:
765
754
  self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
@@ -768,7 +757,8 @@ class ForwardBatch:
768
757
  if self.extend_seq_lens is not None:
769
758
  self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
770
759
 
771
- if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
760
+ if self.spec_info is not None and self.spec_info.is_draft_input():
761
+ # FIXME(lsyin): remove this isinstance logic
772
762
  spec_info = self.spec_info
773
763
  self.output_cache_loc_backup = self.out_cache_loc
774
764
  self.hidden_states_backup = spec_info.hidden_states