sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
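
sglang/srt/layers/attention/flashinfer_backend.py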
@@ -28,8 +28,10 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
+    get_int_env_var,
     is_flashinfer_available,
     is_sm100_supported,
     next_power_of_2,
@@ -39,11 +41,13 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+
 if is_flashinfer_available():
     from flashinfer import (
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
+        fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
     from flashinfer.decode import _get_range_buf, get_seq_lens
@@ -122,12 +126,33 @@ class FlashInferAttnBackend(AttentionBackend):
         ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
+        # When deterministic inference is enabled, tensor cores should be used for decode
+        # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
+        # More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+        self.prefill_split_tile_size = None
+        self.decode_split_tile_size = None
+        self.disable_cuda_graph_kv_split = False
+        if self.enable_deterministic:
+            self.decode_use_tensor_cores = True
+            self.prefill_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
+            )
+            self.decode_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
+            )
+            self.disable_cuda_graph_kv_split = True
+            global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
+            global_workspace_size = global_config.flashinfer_workspace_size
             global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
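The new `__init__` block above is the deterministic-inference wiring: tensor-core decode is forced, fixed split tile sizes come from two environment variables, KV splitting is disabled under cuda graph, and the FlashInfer workspace grows to 2 GiB (see the linked flashinfer PR for the kernel-side details). A minimal sketch of the environment-variable side, assuming `get_int_env_var` from `sglang.srt.utils` behaves like a plain int-with-default lookup (the real helper's parsing may differ):

```python
import os

# Hedged stand-in for sglang.srt.utils.get_int_env_var: read an integer from
# the environment, falling back to the default when unset or non-numeric.
def get_int_env_var(name: str, default: int) -> int:
    value = os.environ.get(name)
    if value is None or not value.strip().lstrip("-").isdigit():
        return default
    return int(value)

# Defaults used when deterministic inference is enabled (values from the diff):
prefill_split_tile_size = get_int_env_var("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096)
decode_split_tile_size = get_int_env_var("SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048)
print(prefill_split_tile_size, decode_split_tile_size)
```

The idea is that a fixed split size pins FlashInfer's KV-split schedule, so the floating-point reduction order does not vary with batch composition; the same motive explains why `use_ragged` is forced off in the prefill hunk further down when deterministic mode is on.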
@@ -218,6 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=forward_batch.spec_info,
+                fixed_split_size=self.decode_split_tile_size,
+                disable_split_kv=False,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         elif forward_batch.forward_mode.is_draft_extend():
@@ -257,7 +284,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged = False
                 extend_no_prefix = False
             else:
-                use_ragged = True
+                use_ragged = not self.enable_deterministic
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
             self.indices_updater_prefill.update(
@@ -270,6 +297,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=None,
+                fixed_split_size=self.prefill_split_tile_size,
             )
             self.forward_metadata = PrefillMetadata(
                 self.prefill_wrappers_paged, use_ragged, extend_no_prefix
@@ -317,7 +345,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -344,6 +372,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrappers
             self.forward_metadata = DecodeMetadata(decode_wrappers)
@@ -422,7 +452,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -434,6 +464,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
         elif forward_mode.is_target_verify():
             self.indices_updater_prefill.update(
@@ -501,8 +533,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )
         else:
             causal = True
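This hunk and the next one make the same change in the two forward paths: the call sites stop passing `layer.k_scale` / `layer.v_scale`, which are device tensors, and pass the cached `layer.k_scale_float` / `layer.v_scale_float` instead, per the inline comment. A small illustrative sketch of the failure mode being avoided (variable names here are ours, not sglang's):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# A KV-cache quantization scale held as a 0-dim tensor on the device:
k_scale = torch.tensor(0.5, device=device)

# Turning it into a Python float is a device-to-host copy plus a host sync.
# Done once at weight-load time, outside any capture, it is harmless:
k_scale_float = k_scale.item()

# Inside torch.cuda.graph(...) capture, that same .item() would be illegal,
# because CUDA graph capture forbids host synchronization. Handing the kernel
# the precomputed float sidesteps the copy entirely.
print(k_scale_float)
```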
@@ -580,8 +613,9 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
+            # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+            k_scale=layer.k_scale_float,
+            v_scale=layer.v_scale_float,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -636,7 +670,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -649,7 +685,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -661,6 +699,8 @@ class FlashInferIndicesUpdaterDecode:
             None,
             spec_info,
             seq_lens_cpu,
+            fixed_split_size=fixed_split_size,
+            disable_split_kv=disable_split_kv,
         )
 
     def update_sliding_window(
@@ -671,7 +711,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         assert self.sliding_window_size is not None
         for wrapper_id in range(2):
@@ -719,7 +761,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -751,9 +795,11 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -797,19 +843,51 @@ class FlashInferIndicesUpdaterDecode:
             global_override_indptr_cpu[0] = 0
             global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
 
-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
+        # Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan
+        # by checking if it's a partial function with fast_decode_plan as the func
+        wrapper_uses_fast_decode_plan = (
+            hasattr(wrapper.begin_forward, "func")
+            and wrapper.begin_forward.func == fast_decode_plan
         )
 
+        if wrapper_uses_fast_decode_plan:
+            # When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+                global_override_indptr_cpu=global_override_indptr_cpu,
+            )
+        else:
+            # When using original begin_forward, don't pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+            )
+
         if locally_override:
             global_override_indptr_cpu = None
 
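`call_begin_forward` now has to know whether this particular wrapper's `begin_forward` was swapped for `fast_decode_plan` (now imported from flashinfer, per the import hunk at the top), because only the fast path accepts `global_override_indptr_cpu`. The detection leans on `functools.partial` exposing its wrapped callable as `.func`. A self-contained sketch of the same pattern, with stand-in classes:

```python
import functools

def fast_decode_plan(wrapper, *args, **kwargs):
    """Stand-in for flashinfer's fast_decode_plan."""

class Wrapper:
    """Stand-in for BatchDecodeWithPagedKVCacheWrapper."""
    def begin_forward(self, *args, **kwargs):
        """Original plan path."""

w = Wrapper()
# Swap the bound method for a partial, as the backend does when it installs
# the fast path on a wrapper instance:
w.begin_forward = functools.partial(fast_decode_plan, w)

# The diff's check: a functools.partial stores its target callable in .func,
# while a plain bound method has no such attribute.
uses_fast_plan = (
    hasattr(w.begin_forward, "func") and w.begin_forward.func == fast_decode_plan
)
assert uses_fast_plan
```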
@@ -856,7 +934,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -871,7 +950,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -895,6 +975,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.qo_indptr[0],
             use_ragged,
             spec_info,
+            fixed_split_size=fixed_split_size,
         )
 
     def update_sliding_window(
@@ -907,7 +988,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
    ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -953,7 +1035,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -995,8 +1078,9 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1022,9 +1106,7 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info, EagleDraftInput) or isinstance(
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
@@ -1067,6 +1149,7 @@ class FlashInferIndicesUpdaterPrefill:
             kv_data_type=self.data_type,
             custom_mask=custom_mask,
             non_blocking=True,
+            fixed_split_size=fixed_split_size,
         )
 
 
@@ -1082,7 +1165,7 @@ class FlashInferMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -1146,7 +1229,7 @@ class FlashInferMultiStepDraftBackend:
             )
 
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()
 
         # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
         indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
@@ -1274,166 +1357,3 @@ def should_use_tensor_core(
         return gqa_group_size >= 4
     else:
         return False
-
-
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global_override_indptr_cpu = None
-
-
-def fast_decode_plan(
-    self,
-    indptr: torch.Tensor,
-    indices: torch.Tensor,
-    last_page_len: torch.Tensor,
-    num_qo_heads: int,
-    num_kv_heads: int,
-    head_dim: int,
-    page_size: int,
-    pos_encoding_mode: str = "NONE",
-    window_left: int = -1,
-    logits_soft_cap: Optional[float] = None,
-    q_data_type: Optional[Union[str, torch.dtype]] = None,
-    kv_data_type: Optional[Union[str, torch.dtype]] = None,
-    data_type: Optional[Union[str, torch.dtype]] = None,
-    sm_scale: Optional[float] = None,
-    rope_scale: Optional[float] = None,
-    rope_theta: Optional[float] = None,
-    non_blocking: bool = True,
-) -> None:
-    """
-    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
-    Modifications:
-    - Remove unnecessary device-to-device copy for the cuda graph buffers.
-    - Remove unnecessary host-to-device copy for the metadata buffers.
-    """
-    batch_size = len(last_page_len)
-    if logits_soft_cap is None:
-        logits_soft_cap = 0.0
-
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-
-    if self.is_cuda_graph_enabled:
-        if batch_size != self._fixed_batch_size:
-            raise ValueError(
-                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
-                " mismatches the batch size set during initialization {}".format(
-                    batch_size, self._fixed_batch_size
-                )
-            )
-        if len(indices) > len(self._paged_kv_indices_buf):
-            raise ValueError(
-                "The size of indices should be less than or equal to the allocated buffer"
-            )
-    else:
-        self._paged_kv_indptr_buf = indptr
-        self._paged_kv_indices_buf = indices
-        self._paged_kv_last_page_len_buf = last_page_len
-        if self.use_tensor_cores:
-            self._qo_indptr_buf = qo_indptr_host.to(
-                self.device, non_blocking=non_blocking
-            )
-
-    # Create empty tensors for dtype info if needed
-    empty_q_data = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-        ),
-        device=self.device,
-    )
-
-    empty_kv_cache = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, kv_data_type)
-            if isinstance(kv_data_type, str)
-            else kv_data_type
-        ),
-        device=self.device,
-    )
-
-    indptr_host = (
-        global_override_indptr_cpu
-        if global_override_indptr_cpu is not None
-        else indptr.cpu()
-    )
-
-    with torch.cuda.device(self.device):
-
-        if self.use_tensor_cores:
-            # ALSO convert last_page_len to CPU
-            if page_size == 1:
-                # When page size is 1, last_page_len is always 1.
-                # Directly construct the host tensor rather than executing a device-to-host copy.
-                last_page_len_host = torch.ones(
-                    (batch_size,), dtype=torch.int32, device="cpu"
-                )
-            else:
-                last_page_len_host = last_page_len.cpu()
-
-            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
-
-            try:
-                # Make sure we pass exactly 15 arguments for tensor core version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    qo_indptr_host,
-                    indptr_host,
-                    kv_lens_arr_host,
-                    batch_size,  # total_num_rows
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    head_dim,
-                    head_dim,
-                    False,  # causal
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-        else:
-            try:
-                # Make sure we pass exactly 15 arguments for standard version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    indptr_host,
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    window_left,
-                    logits_soft_cap,
-                    head_dim,
-                    head_dim,
-                    empty_q_data,
-                    empty_kv_cache,
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-
-    self._pos_encoding_mode = pos_encoding_mode
-    self._window_left = window_left
-    self._logits_soft_cap = logits_soft_cap
-    self._sm_scale = sm_scale
-    self._rope_scale = rope_scale
-    self._rope_theta = rope_theta
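
sglang/srt/layers/attention/flashinfer_mla_backend.py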
@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
     is_sm100_supported,
@@ -40,7 +40,7 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 
 if is_flashinfer_available():
     from flashinfer import (
@@ -96,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
     def update_wrapper(
         self,
         forward_batch: ForwardBatch,
+        disable_flashinfer_ragged: bool = False,
     ):
         assert forward_batch.num_prefix_chunks is not None
         num_prefix_chunks = forward_batch.num_prefix_chunks
@@ -128,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
                 causal=False,
             )
         # ragged prefill
-        self.ragged_wrapper.begin_forward(
-            qo_indptr=qo_indptr,
-            kv_indptr=qo_indptr,
-            num_qo_heads=self.num_local_heads,
-            num_kv_heads=self.num_local_heads,
-            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-            head_dim_vo=self.v_head_dim,
-            q_data_type=self.q_data_type,
-            causal=True,
-        )
+        if not disable_flashinfer_ragged:
+            self.ragged_wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=True,
+            )
 
     def forward(
         self,
@@ -359,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
@@ -439,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -491,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
-    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+    def init_mha_chunk_metadata(
+        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
+    ):
         """Init the metadata for a forward pass."""
-        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)
 
     def forward_extend(
         self,
@@ -659,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
         init_metadata_replay: bool = False,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -684,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
         init_metadata_replay: bool = False,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         bs = len(req_pool_indices)
@@ -772,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -807,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
     ):
         bs = len(seq_lens)
         sm_scale = self.scaling
@@ -834,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info, EagleDraftInput) or isinstance(
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             # TODO: Support topk > 1 with custom mask
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -890,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 
         if topk > 1:
             raise ValueError(
@@ -959,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
             )
 
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()
 
         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
@@ -979,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
             )
 
         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
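
A pattern repeated across both files: `Optional[Union[EagleDraftInput, EagleVerifyInput]]` annotations and `isinstance(..., EagleDraftInput)` asserts give way to the common `SpecInput` base from `sglang.srt.speculative.spec_info`, queried through `is_draft_input()`. A hedged sketch of what that interface implies; only `SpecInput`, `is_draft_input`, and the two Eagle class names appear in these hunks, the rest is illustrative:

```python
from abc import ABC, abstractmethod

class SpecInput(ABC):
    """Sketch of the shared speculative-input interface the diff migrates to."""

    @abstractmethod
    def is_draft_input(self) -> bool: ...

class EagleDraftInput(SpecInput):
    def is_draft_input(self) -> bool:
        return True

class EagleVerifyInput(SpecInput):
    def is_draft_input(self) -> bool:
        return False

# Attention backends can now annotate Optional[SpecInput] and branch without
# importing the Eagle-specific classes:
spec_info: SpecInput = EagleDraftInput()
assert isinstance(spec_info, SpecInput)
assert spec_info.is_draft_input()
```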