sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -28,8 +28,10 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
28
28
  from sglang.srt.layers.radix_attention import AttentionType
29
29
  from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
30
30
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
31
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
31
+ from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
32
+ from sglang.srt.speculative.spec_info import SpecInput
32
33
  from sglang.srt.utils import (
34
+ get_int_env_var,
33
35
  is_flashinfer_available,
34
36
  is_sm100_supported,
35
37
  next_power_of_2,
@@ -39,11 +41,13 @@ if TYPE_CHECKING:
39
41
  from sglang.srt.layers.radix_attention import RadixAttention
40
42
  from sglang.srt.model_executor.model_runner import ModelRunner
41
43
 
44
+
42
45
  if is_flashinfer_available():
43
46
  from flashinfer import (
44
47
  BatchDecodeWithPagedKVCacheWrapper,
45
48
  BatchPrefillWithPagedKVCacheWrapper,
46
49
  BatchPrefillWithRaggedKVCacheWrapper,
50
+ fast_decode_plan,
47
51
  )
48
52
  from flashinfer.cascade import merge_state
49
53
  from flashinfer.decode import _get_range_buf, get_seq_lens
@@ -122,12 +126,33 @@ class FlashInferAttnBackend(AttentionBackend):
122
126
  ):
123
127
  global_config.flashinfer_workspace_size = 512 * 1024 * 1024
124
128
 
129
+ # When deterministic inference is enabled, tensor cores should be used for decode
130
+ # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
131
+ # More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675
132
+ self.enable_deterministic = (
133
+ model_runner.server_args.enable_deterministic_inference
134
+ )
135
+ self.prefill_split_tile_size = None
136
+ self.decode_split_tile_size = None
137
+ self.disable_cuda_graph_kv_split = False
138
+ if self.enable_deterministic:
139
+ self.decode_use_tensor_cores = True
140
+ self.prefill_split_tile_size = get_int_env_var(
141
+ "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
142
+ )
143
+ self.decode_split_tile_size = get_int_env_var(
144
+ "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
145
+ )
146
+ self.disable_cuda_graph_kv_split = True
147
+ global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
148
+
125
149
  # Allocate buffers
126
150
  global global_workspace_buffer
127
151
  if global_workspace_buffer is None:
128
152
  # different from flashinfer zero_init_global_workspace_buffer
153
+ global_workspace_size = global_config.flashinfer_workspace_size
129
154
  global_workspace_buffer = torch.empty(
130
- global_config.flashinfer_workspace_size,
155
+ global_workspace_size,
131
156
  dtype=torch.uint8,
132
157
  device=model_runner.device,
133
158
  )
@@ -218,6 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
218
243
  decode_wrappers=self.decode_wrappers,
219
244
  encoder_lens=forward_batch.encoder_lens,
220
245
  spec_info=forward_batch.spec_info,
246
+ fixed_split_size=self.decode_split_tile_size,
247
+ disable_split_kv=False,
221
248
  )
222
249
  self.forward_metadata = DecodeMetadata(self.decode_wrappers)
223
250
  elif forward_batch.forward_mode.is_draft_extend():
@@ -257,7 +284,7 @@ class FlashInferAttnBackend(AttentionBackend):
257
284
  use_ragged = False
258
285
  extend_no_prefix = False
259
286
  else:
260
- use_ragged = True
287
+ use_ragged = not self.enable_deterministic
261
288
  extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
262
289
 
263
290
  self.indices_updater_prefill.update(
@@ -270,6 +297,7 @@ class FlashInferAttnBackend(AttentionBackend):
270
297
  use_ragged=use_ragged,
271
298
  encoder_lens=forward_batch.encoder_lens,
272
299
  spec_info=None,
300
+ fixed_split_size=self.prefill_split_tile_size,
273
301
  )
274
302
  self.forward_metadata = PrefillMetadata(
275
303
  self.prefill_wrappers_paged, use_ragged, extend_no_prefix
@@ -317,7 +345,7 @@ class FlashInferAttnBackend(AttentionBackend):
317
345
  seq_lens: torch.Tensor,
318
346
  encoder_lens: Optional[torch.Tensor],
319
347
  forward_mode: ForwardMode,
320
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
348
+ spec_info: Optional[SpecInput],
321
349
  ):
322
350
  if forward_mode.is_decode_or_idle():
323
351
  decode_wrappers = []
@@ -344,6 +372,8 @@ class FlashInferAttnBackend(AttentionBackend):
344
372
  decode_wrappers=decode_wrappers,
345
373
  encoder_lens=encoder_lens,
346
374
  spec_info=spec_info,
375
+ fixed_split_size=None,
376
+ disable_split_kv=self.disable_cuda_graph_kv_split,
347
377
  )
348
378
  self.decode_cuda_graph_metadata[bs] = decode_wrappers
349
379
  self.forward_metadata = DecodeMetadata(decode_wrappers)
@@ -422,7 +452,7 @@ class FlashInferAttnBackend(AttentionBackend):
422
452
  seq_lens_sum: int,
423
453
  encoder_lens: Optional[torch.Tensor],
424
454
  forward_mode: ForwardMode,
425
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
455
+ spec_info: Optional[SpecInput],
426
456
  seq_lens_cpu: Optional[torch.Tensor],
427
457
  ):
428
458
  if forward_mode.is_decode_or_idle():
@@ -434,6 +464,8 @@ class FlashInferAttnBackend(AttentionBackend):
434
464
  decode_wrappers=self.decode_cuda_graph_metadata[bs],
435
465
  encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
436
466
  spec_info=spec_info,
467
+ fixed_split_size=None,
468
+ disable_split_kv=self.disable_cuda_graph_kv_split,
437
469
  )
438
470
  elif forward_mode.is_target_verify():
439
471
  self.indices_updater_prefill.update(
@@ -638,7 +670,9 @@ class FlashInferIndicesUpdaterDecode:
638
670
  seq_lens_sum: int,
639
671
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
640
672
  encoder_lens: Optional[torch.Tensor],
641
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
673
+ spec_info: Optional[SpecInput],
674
+ fixed_split_size: Optional[int] = None,
675
+ disable_split_kv: Optional[bool] = None,
642
676
  ):
643
677
  # Keep the signature for type checking. It will be assigned during runtime.
644
678
  raise NotImplementedError()
@@ -651,7 +685,9 @@ class FlashInferIndicesUpdaterDecode:
651
685
  seq_lens_sum: int,
652
686
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
653
687
  encoder_lens: Optional[torch.Tensor],
654
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
688
+ spec_info: Optional[SpecInput],
689
+ fixed_split_size: Optional[int] = None,
690
+ disable_split_kv: Optional[bool] = None,
655
691
  ):
656
692
  decode_wrappers = decode_wrappers or self.decode_wrappers
657
693
  self.call_begin_forward(
@@ -663,6 +699,8 @@ class FlashInferIndicesUpdaterDecode:
663
699
  None,
664
700
  spec_info,
665
701
  seq_lens_cpu,
702
+ fixed_split_size=fixed_split_size,
703
+ disable_split_kv=disable_split_kv,
666
704
  )
667
705
 
668
706
  def update_sliding_window(
@@ -673,7 +711,9 @@ class FlashInferIndicesUpdaterDecode:
673
711
  seq_lens_sum: int,
674
712
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
675
713
  encoder_lens: Optional[torch.Tensor],
676
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
714
+ spec_info: Optional[SpecInput],
715
+ fixed_split_size: Optional[int] = None,
716
+ disable_split_kv: Optional[bool] = None,
677
717
  ):
678
718
  assert self.sliding_window_size is not None
679
719
  for wrapper_id in range(2):
@@ -721,7 +761,9 @@ class FlashInferIndicesUpdaterDecode:
721
761
  seq_lens_sum: int,
722
762
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
723
763
  encoder_lens: Optional[torch.Tensor],
724
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
764
+ spec_info: Optional[SpecInput],
765
+ fixed_split_size: Optional[int] = None,
766
+ disable_split_kv: Optional[bool] = None,
725
767
  ):
726
768
  for wrapper_id in range(2):
727
769
  if wrapper_id == 0:
@@ -753,9 +795,11 @@ class FlashInferIndicesUpdaterDecode:
753
795
  paged_kernel_lens_sum: int,
754
796
  kv_indptr: torch.Tensor,
755
797
  kv_start_idx: torch.Tensor,
756
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
798
+ spec_info: Optional[SpecInput],
757
799
  seq_lens_cpu: Optional[torch.Tensor],
758
800
  use_sliding_window_kv_pool: bool = False,
801
+ fixed_split_size: Optional[int] = None,
802
+ disable_split_kv: Optional[bool] = None,
759
803
  ):
760
804
  if spec_info is None:
761
805
  bs = len(req_pool_indices)
@@ -799,19 +843,51 @@ class FlashInferIndicesUpdaterDecode:
799
843
  global_override_indptr_cpu[0] = 0
800
844
  global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
801
845
 
802
- wrapper.begin_forward(
803
- kv_indptr,
804
- kv_indices,
805
- self.kv_last_page_len[:bs],
806
- self.num_qo_heads,
807
- self.num_kv_heads,
808
- self.head_dim,
809
- 1,
810
- data_type=self.data_type,
811
- q_data_type=self.q_data_type,
812
- non_blocking=True,
846
+ # Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan
847
+ # by checking if it's a partial function with fast_decode_plan as the func
848
+ wrapper_uses_fast_decode_plan = (
849
+ hasattr(wrapper.begin_forward, "func")
850
+ and wrapper.begin_forward.func == fast_decode_plan
813
851
  )
814
852
 
853
+ if wrapper_uses_fast_decode_plan:
854
+ # When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu
855
+ wrapper.begin_forward(
856
+ kv_indptr,
857
+ kv_indices,
858
+ self.kv_last_page_len[:bs],
859
+ self.num_qo_heads,
860
+ self.num_kv_heads,
861
+ self.head_dim,
862
+ 1,
863
+ data_type=self.data_type,
864
+ q_data_type=self.q_data_type,
865
+ non_blocking=True,
866
+ fixed_split_size=fixed_split_size,
867
+ disable_split_kv=(
868
+ disable_split_kv if disable_split_kv is not None else False
869
+ ),
870
+ global_override_indptr_cpu=global_override_indptr_cpu,
871
+ )
872
+ else:
873
+ # When using original begin_forward, don't pass global_override_indptr_cpu
874
+ wrapper.begin_forward(
875
+ kv_indptr,
876
+ kv_indices,
877
+ self.kv_last_page_len[:bs],
878
+ self.num_qo_heads,
879
+ self.num_kv_heads,
880
+ self.head_dim,
881
+ 1,
882
+ data_type=self.data_type,
883
+ q_data_type=self.q_data_type,
884
+ non_blocking=True,
885
+ fixed_split_size=fixed_split_size,
886
+ disable_split_kv=(
887
+ disable_split_kv if disable_split_kv is not None else False
888
+ ),
889
+ )
890
+
815
891
  if locally_override:
816
892
  global_override_indptr_cpu = None
817
893
 
@@ -858,7 +934,8 @@ class FlashInferIndicesUpdaterPrefill:
858
934
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
859
935
  use_ragged: bool,
860
936
  encoder_lens: Optional[torch.Tensor],
861
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
937
+ spec_info: Optional[SpecInput],
938
+ fixed_split_size: Optional[int] = None,
862
939
  ):
863
940
  # Keep the signature for type checking. It will be assigned during runtime.
864
941
  raise NotImplementedError()
@@ -873,7 +950,8 @@ class FlashInferIndicesUpdaterPrefill:
873
950
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
874
951
  use_ragged: bool,
875
952
  encoder_lens: Optional[torch.Tensor],
876
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
953
+ spec_info: Optional[SpecInput],
954
+ fixed_split_size: Optional[int] = None,
877
955
  ):
878
956
  if use_ragged:
879
957
  # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -897,6 +975,7 @@ class FlashInferIndicesUpdaterPrefill:
897
975
  self.qo_indptr[0],
898
976
  use_ragged,
899
977
  spec_info,
978
+ fixed_split_size=fixed_split_size,
900
979
  )
901
980
 
902
981
  def update_sliding_window(
@@ -909,7 +988,8 @@ class FlashInferIndicesUpdaterPrefill:
909
988
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
910
989
  use_ragged: bool,
911
990
  encoder_lens: Optional[torch.Tensor],
912
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
991
+ spec_info: Optional[SpecInput],
992
+ fixed_split_size: Optional[int] = None,
913
993
  ):
914
994
  for wrapper_id in range(2):
915
995
  if wrapper_id == 0:
@@ -955,7 +1035,8 @@ class FlashInferIndicesUpdaterPrefill:
955
1035
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
956
1036
  use_ragged: bool,
957
1037
  encoder_lens: Optional[torch.Tensor],
958
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
1038
+ spec_info: Optional[SpecInput],
1039
+ fixed_split_size: Optional[int] = None,
959
1040
  ):
960
1041
  for wrapper_id in range(2):
961
1042
  if wrapper_id == 0:
@@ -997,8 +1078,9 @@ class FlashInferIndicesUpdaterPrefill:
997
1078
  kv_indptr: torch.Tensor,
998
1079
  qo_indptr: torch.Tensor,
999
1080
  use_ragged: bool,
1000
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
1081
+ spec_info: Optional[SpecInput],
1001
1082
  use_sliding_window_kv_pool: bool = False,
1083
+ fixed_split_size: Optional[int] = None,
1002
1084
  ):
1003
1085
  bs = len(seq_lens)
1004
1086
  if spec_info is None:
@@ -1024,9 +1106,7 @@ class FlashInferIndicesUpdaterPrefill:
1024
1106
  qo_indptr = qo_indptr[: bs + 1]
1025
1107
  custom_mask = None
1026
1108
  else:
1027
- assert isinstance(spec_info, EagleDraftInput) or isinstance(
1028
- spec_info, EagleVerifyInput
1029
- )
1109
+ assert isinstance(spec_info, SpecInput)
1030
1110
  kv_indices, kv_indptr, qo_indptr, custom_mask = (
1031
1111
  spec_info.generate_attn_arg_prefill(
1032
1112
  req_pool_indices,
@@ -1069,6 +1149,7 @@ class FlashInferIndicesUpdaterPrefill:
1069
1149
  kv_data_type=self.data_type,
1070
1150
  custom_mask=custom_mask,
1071
1151
  non_blocking=True,
1152
+ fixed_split_size=fixed_split_size,
1072
1153
  )
1073
1154
 
1074
1155
 
@@ -1084,7 +1165,7 @@ class FlashInferMultiStepDraftBackend:
1084
1165
  topk: int,
1085
1166
  speculative_num_steps: int,
1086
1167
  ):
1087
- from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
1168
+ from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
1088
1169
 
1089
1170
  self.topk = topk
1090
1171
  self.speculative_num_steps = speculative_num_steps
@@ -1148,7 +1229,7 @@ class FlashInferMultiStepDraftBackend:
1148
1229
  )
1149
1230
 
1150
1231
  assert forward_batch.spec_info is not None
1151
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
1232
+ assert forward_batch.spec_info.is_draft_input()
1152
1233
 
1153
1234
  # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
1154
1235
  indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
@@ -1276,166 +1357,3 @@ def should_use_tensor_core(
1276
1357
  return gqa_group_size >= 4
1277
1358
  else:
1278
1359
  return False
1279
-
1280
-
1281
- # Use as a fast path to override the indptr in flashinfer's plan function
1282
- # This is used to remove some host-to-device copy overhead.
1283
- global_override_indptr_cpu = None
1284
-
1285
-
1286
- def fast_decode_plan(
1287
- self,
1288
- indptr: torch.Tensor,
1289
- indices: torch.Tensor,
1290
- last_page_len: torch.Tensor,
1291
- num_qo_heads: int,
1292
- num_kv_heads: int,
1293
- head_dim: int,
1294
- page_size: int,
1295
- pos_encoding_mode: str = "NONE",
1296
- window_left: int = -1,
1297
- logits_soft_cap: Optional[float] = None,
1298
- q_data_type: Optional[Union[str, torch.dtype]] = None,
1299
- kv_data_type: Optional[Union[str, torch.dtype]] = None,
1300
- data_type: Optional[Union[str, torch.dtype]] = None,
1301
- sm_scale: Optional[float] = None,
1302
- rope_scale: Optional[float] = None,
1303
- rope_theta: Optional[float] = None,
1304
- non_blocking: bool = True,
1305
- ) -> None:
1306
- """
1307
- A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
1308
- Modifications:
1309
- - Remove unnecessary device-to-device copy for the cuda graph buffers.
1310
- - Remove unnecessary host-to-device copy for the metadata buffers.
1311
- """
1312
- batch_size = len(last_page_len)
1313
- if logits_soft_cap is None:
1314
- logits_soft_cap = 0.0
1315
-
1316
- # Handle data types consistently
1317
- if data_type is not None:
1318
- if q_data_type is None:
1319
- q_data_type = data_type
1320
- if kv_data_type is None:
1321
- kv_data_type = data_type
1322
- elif q_data_type is None:
1323
- q_data_type = "float16"
1324
-
1325
- if kv_data_type is None:
1326
- kv_data_type = q_data_type
1327
-
1328
- if self.use_tensor_cores:
1329
- qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
1330
-
1331
- if self.is_cuda_graph_enabled:
1332
- if batch_size != self._fixed_batch_size:
1333
- raise ValueError(
1334
- "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
1335
- " mismatches the batch size set during initialization {}".format(
1336
- batch_size, self._fixed_batch_size
1337
- )
1338
- )
1339
- if len(indices) > len(self._paged_kv_indices_buf):
1340
- raise ValueError(
1341
- "The size of indices should be less than or equal to the allocated buffer"
1342
- )
1343
- else:
1344
- self._paged_kv_indptr_buf = indptr
1345
- self._paged_kv_indices_buf = indices
1346
- self._paged_kv_last_page_len_buf = last_page_len
1347
- if self.use_tensor_cores:
1348
- self._qo_indptr_buf = qo_indptr_host.to(
1349
- self.device, non_blocking=non_blocking
1350
- )
1351
-
1352
- # Create empty tensors for dtype info if needed
1353
- empty_q_data = torch.empty(
1354
- 0,
1355
- dtype=(
1356
- getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
1357
- ),
1358
- device=self.device,
1359
- )
1360
-
1361
- empty_kv_cache = torch.empty(
1362
- 0,
1363
- dtype=(
1364
- getattr(torch, kv_data_type)
1365
- if isinstance(kv_data_type, str)
1366
- else kv_data_type
1367
- ),
1368
- device=self.device,
1369
- )
1370
-
1371
- indptr_host = (
1372
- global_override_indptr_cpu
1373
- if global_override_indptr_cpu is not None
1374
- else indptr.cpu()
1375
- )
1376
-
1377
- with torch.cuda.device(self.device):
1378
-
1379
- if self.use_tensor_cores:
1380
- # ALSO convert last_page_len to CPU
1381
- if page_size == 1:
1382
- # When page size is 1, last_page_len is always 1.
1383
- # Directly construct the host tensor rather than executing a device-to-host copy.
1384
- last_page_len_host = torch.ones(
1385
- (batch_size,), dtype=torch.int32, device="cpu"
1386
- )
1387
- else:
1388
- last_page_len_host = last_page_len.cpu()
1389
-
1390
- kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
1391
-
1392
- try:
1393
- # Make sure we pass exactly 15 arguments for tensor core version
1394
- self._plan_info = self._cached_module.plan(
1395
- self._float_workspace_buffer,
1396
- self._int_workspace_buffer,
1397
- self._pin_memory_int_workspace_buffer,
1398
- qo_indptr_host,
1399
- indptr_host,
1400
- kv_lens_arr_host,
1401
- batch_size, # total_num_rows
1402
- batch_size,
1403
- num_qo_heads,
1404
- num_kv_heads,
1405
- page_size,
1406
- self.is_cuda_graph_enabled,
1407
- head_dim,
1408
- head_dim,
1409
- False, # causal
1410
- )
1411
- except Exception as e:
1412
- raise RuntimeError(f"Error in standard plan: {e}")
1413
- else:
1414
- try:
1415
- # Make sure we pass exactly 15 arguments for standard version
1416
- self._plan_info = self._cached_module.plan(
1417
- self._float_workspace_buffer,
1418
- self._int_workspace_buffer,
1419
- self._pin_memory_int_workspace_buffer,
1420
- indptr_host,
1421
- batch_size,
1422
- num_qo_heads,
1423
- num_kv_heads,
1424
- page_size,
1425
- self.is_cuda_graph_enabled,
1426
- window_left,
1427
- logits_soft_cap,
1428
- head_dim,
1429
- head_dim,
1430
- empty_q_data,
1431
- empty_kv_cache,
1432
- )
1433
- except Exception as e:
1434
- raise RuntimeError(f"Error in standard plan: {e}")
1435
-
1436
- self._pos_encoding_mode = pos_encoding_mode
1437
- self._window_left = window_left
1438
- self._logits_soft_cap = logits_soft_cap
1439
- self._sm_scale = sm_scale
1440
- self._rope_scale = rope_scale
1441
- self._rope_theta = rope_theta
@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
30
30
  from sglang.srt.layers.dp_attention import get_attention_tp_size
31
31
  from sglang.srt.managers.schedule_batch import global_server_args_dict
32
32
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
33
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
33
+ from sglang.srt.speculative.spec_info import SpecInput
34
34
  from sglang.srt.utils import (
35
35
  is_flashinfer_available,
36
36
  is_sm100_supported,
@@ -40,7 +40,7 @@ from sglang.srt.utils import (
40
40
  if TYPE_CHECKING:
41
41
  from sglang.srt.layers.radix_attention import RadixAttention
42
42
  from sglang.srt.model_executor.model_runner import ModelRunner
43
- from sglang.srt.speculative.spec_info import SpecInfo
43
+ from sglang.srt.speculative.spec_info import SpecInput
44
44
 
45
45
  if is_flashinfer_available():
46
46
  from flashinfer import (
@@ -361,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
361
361
  seq_lens: torch.Tensor,
362
362
  encoder_lens: Optional[torch.Tensor],
363
363
  forward_mode: ForwardMode,
364
- spec_info: Optional[SpecInfo],
364
+ spec_info: Optional[SpecInput],
365
365
  ):
366
366
  if forward_mode.is_decode_or_idle():
367
367
  decode_wrapper = BatchMLAPagedAttentionWrapper(
@@ -441,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
441
441
  seq_lens_sum: int,
442
442
  encoder_lens: Optional[torch.Tensor],
443
443
  forward_mode: ForwardMode,
444
- spec_info: Optional[SpecInfo],
444
+ spec_info: Optional[SpecInput],
445
445
  seq_lens_cpu: Optional[torch.Tensor],
446
446
  ):
447
447
  if forward_mode.is_decode_or_idle():
@@ -663,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
663
663
  seq_lens_sum: int,
664
664
  decode_wrapper: BatchMLAPagedAttentionWrapper,
665
665
  init_metadata_replay: bool = False,
666
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
666
+ spec_info: Optional[SpecInput] = None,
667
667
  **fast_decode_kwargs,
668
668
  ):
669
669
  decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -688,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
688
688
  q_indptr: torch.Tensor,
689
689
  kv_indptr: torch.Tensor,
690
690
  init_metadata_replay: bool = False,
691
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
691
+ spec_info: Optional[SpecInput] = None,
692
692
  **fast_decode_kwargs,
693
693
  ):
694
694
  bs = len(req_pool_indices)
@@ -776,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
776
776
  prefix_lens: torch.Tensor,
777
777
  prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
778
778
  use_ragged: bool,
779
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
779
+ spec_info: Optional[SpecInput] = None,
780
780
  ):
781
781
  if use_ragged:
782
782
  paged_kernel_lens = prefix_lens
@@ -811,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
811
811
  kv_indptr: torch.Tensor,
812
812
  qo_indptr: torch.Tensor,
813
813
  use_ragged: bool,
814
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
814
+ spec_info: Optional[SpecInput] = None,
815
815
  ):
816
816
  bs = len(seq_lens)
817
817
  sm_scale = self.scaling
@@ -838,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
838
838
  qo_indptr = qo_indptr[: bs + 1]
839
839
  custom_mask = None
840
840
  else:
841
- assert isinstance(spec_info, EagleDraftInput) or isinstance(
842
- spec_info, EagleVerifyInput
843
- )
841
+ assert isinstance(spec_info, SpecInput)
844
842
  # TODO: Support topk > 1 with custom mask
845
843
  kv_indices, kv_indptr, qo_indptr, custom_mask = (
846
844
  spec_info.generate_attn_arg_prefill(
@@ -894,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
894
892
  topk: int,
895
893
  speculative_num_steps: int,
896
894
  ):
897
- from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
895
+ from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
898
896
 
899
897
  if topk > 1:
900
898
  raise ValueError(
@@ -963,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
963
961
  )
964
962
 
965
963
  assert forward_batch.spec_info is not None
966
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
964
+ assert forward_batch.spec_info.is_draft_input()
967
965
 
968
966
  for i in range(self.speculative_num_steps - 1):
969
967
  forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
@@ -983,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
983
981
  )
984
982
 
985
983
  def call_fn(i, forward_batch):
986
- assert forward_batch.spec_info is not None
987
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
988
984
  forward_batch.spec_info.kv_indptr = (
989
985
  forward_batch.spec_info.kv_indptr.clone()
990
986
  )
@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
19
19
  if TYPE_CHECKING:
20
20
  from sglang.srt.layers.radix_attention import RadixAttention
21
21
  from sglang.srt.model_executor.model_runner import ModelRunner
22
- from sglang.srt.speculative.spec_info import SpecInfo
22
+ from sglang.srt.speculative.spec_info import SpecInput
23
23
 
24
24
 
25
25
  # FlashMLA only supports pagesize=64
@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
187
187
  seq_lens: torch.Tensor,
188
188
  encoder_lens: Optional[torch.Tensor],
189
189
  forward_mode: ForwardMode,
190
- spec_info: Optional[SpecInfo],
190
+ spec_info: Optional[SpecInput],
191
191
  ):
192
192
  if forward_mode.is_decode_or_idle():
193
193
  max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
@@ -201,9 +201,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
201
201
  self.req_to_token.stride(0),
202
202
  self.cuda_graph_kv_indices.stride(0),
203
203
  )
204
+ num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
204
205
  mla_metadata, num_splits = get_mla_metadata(
205
206
  seq_lens.to(torch.int32),
206
- self.num_q_heads,
207
+ num_q_heads,
207
208
  1,
208
209
  )
209
210
  self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -257,7 +258,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
257
258
  seq_lens_sum: int,
258
259
  encoder_lens: Optional[torch.Tensor],
259
260
  forward_mode: ForwardMode,
260
- spec_info: Optional[SpecInfo],
261
+ spec_info: Optional[SpecInput],
261
262
  seq_lens_cpu: Optional[torch.Tensor],
262
263
  ):
263
264
 
@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
275
276
  self.req_to_token.stride(0),
276
277
  self.cuda_graph_kv_indices.stride(0),
277
278
  )
279
+ num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
278
280
  mla_metadata, num_splits = get_mla_metadata(
279
281
  seq_lens.to(torch.int32),
280
- self.num_q_heads,
282
+ num_q_heads,
281
283
  1,
282
284
  )
283
285
  self.cuda_graph_mla_metadata.copy_(mla_metadata)