sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,10 @@
1
- from typing import TYPE_CHECKING, Callable, List, Optional, Union
1
+ from typing import TYPE_CHECKING, Callable, List, Optional
2
2
 
3
3
  import torch
4
4
 
5
5
  from sglang.srt import two_batch_overlap
6
6
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
7
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
7
+ from sglang.srt.speculative.spec_info import SpecInput
8
8
 
9
9
  if TYPE_CHECKING:
10
10
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend):
46
46
  seq_lens: torch.Tensor,
47
47
  encoder_lens: Optional[torch.Tensor],
48
48
  forward_mode: "ForwardMode",
49
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
49
+ spec_info: Optional[SpecInput],
50
50
  ):
51
51
  self.primary.init_forward_metadata_capture_cuda_graph(
52
52
  bs=bs,
@@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend):
77
77
  seq_lens_sum: int,
78
78
  encoder_lens: Optional[torch.Tensor],
79
79
  forward_mode: "ForwardMode",
80
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
80
+ spec_info: Optional[SpecInput],
81
81
  seq_lens_cpu: Optional[torch.Tensor],
82
82
  ):
83
83
  self.primary.init_forward_metadata_replay_cuda_graph(
@@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend):
112
112
  seq_lens: torch.Tensor,
113
113
  encoder_lens: Optional[torch.Tensor],
114
114
  forward_mode: "ForwardMode",
115
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
115
+ spec_info: Optional[SpecInput],
116
116
  # capture args
117
117
  capture_num_tokens: int = None,
118
118
  # replay args
@@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split(
196
196
  seq_lens: torch.Tensor,
197
197
  encoder_lens: Optional[torch.Tensor],
198
198
  forward_mode: "ForwardMode",
199
- spec_info: Optional[EagleVerifyInput],
199
+ spec_info: Optional[SpecInput],
200
200
  # capture args
201
201
  capture_num_tokens: int = None,
202
202
  # replay args
@@ -0,0 +1,325 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import torch
6
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
7
+
8
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
9
+ from sglang.srt.layers.radix_attention import AttentionType
10
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
11
+
12
+ if TYPE_CHECKING:
13
+ from sglang.srt.layers.radix_attention import RadixAttention
14
+ from sglang.srt.model_executor.model_runner import ModelRunner
15
+
16
+
17
+ class TorchFlexAttnBackend(AttentionBackend):
18
+ def __init__(self, model_runner: ModelRunner):
19
+ super().__init__()
20
+ self.forward_metadata = None
21
+ self.device = model_runner.device
22
+ self.flex_attention = torch.compile(flex_attention, dynamic=True)
23
+ torch._dynamo.config.cache_size_limit = 1024
24
+ torch._dynamo.config.accumulated_cache_size_limit = 1024
25
+
26
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
27
+ """Init the metadata for a forward pass."""
28
+ # TODO: find a more elegant way to save memory
29
+ # Currently maintain the same memory as torch_native_backend
30
+ torch.cuda.empty_cache()
31
+
32
+ # Provide two block_mask Lists per seq_idx for lower latency, later will support per layer level mask generation
33
+ self.extend_block_masks = []
34
+ self.decode_block_masks = []
35
+
36
+ if forward_batch.forward_mode.is_extend():
37
+ for seq_idx in range(forward_batch.seq_lens.shape[0]):
38
+ seq_len_kv = forward_batch.seq_lens[seq_idx]
39
+ seq_len_q = seq_len_kv
40
+ self.extend_block_masks.append(
41
+ create_block_mask(
42
+ self._causal_mask,
43
+ None,
44
+ None,
45
+ seq_len_q,
46
+ seq_len_kv,
47
+ device=self.device,
48
+ _compile=False,
49
+ )
50
+ )
51
+
52
+ elif forward_batch.forward_mode.is_decode():
53
+ for seq_idx in range(forward_batch.seq_lens.shape[0]):
54
+ seq_len_q = 1
55
+ seq_len_kv = forward_batch.seq_lens[seq_idx]
56
+
57
+ self.decode_block_masks.append(
58
+ create_block_mask(
59
+ self._decode_mask,
60
+ None,
61
+ None,
62
+ seq_len_q,
63
+ seq_len_kv,
64
+ device=self.device,
65
+ _compile=False,
66
+ )
67
+ )
68
+
69
+ def _causal_mask(self, b, h, q_idx, kv_idx):
70
+ return q_idx >= kv_idx
71
+
72
+ def _decode_mask(self, b, h, q_idx, kv_idx):
73
+ return q_idx <= kv_idx
74
+
75
+ def _run_flex_forward_extend(
76
+ self,
77
+ query: torch.Tensor,
78
+ output: torch.Tensor,
79
+ k_cache: torch.Tensor,
80
+ v_cache: torch.Tensor,
81
+ req_to_token: torch.Tensor,
82
+ req_pool_indices: torch.Tensor,
83
+ seq_lens: torch.Tensor,
84
+ extend_prefix_lens: torch.Tensor,
85
+ extend_seq_lens: torch.Tensor,
86
+ scaling=None,
87
+ enable_gqa=False,
88
+ causal=False,
89
+ ):
90
+ """Run the extend forward by using torch flex attention op.
91
+
92
+ Args:
93
+ query: [num_tokens, num_heads, head_size]
94
+ output: [num_tokens, num_heads, head_size]
95
+ k_cache: [max_total_num_tokens, num_heads, head_size]
96
+ v_cache: [max_total_num_tokens, num_heads, head_size]
97
+ req_to_token: [max_num_reqs, max_context_len]
98
+ req_pool_indices: [num_seqs]
99
+ seq_lens: [num_seqs]
100
+ extend_prefix_lens: [num_seqs]
101
+ extend_seq_lens: [num_seqs]
102
+ scaling: float or None
103
+ enable_gqa: bool
104
+ causal: bool
105
+
106
+ Returns:
107
+ output: [num_tokens, num_heads, head_size]
108
+ """
109
+
110
+ assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
111
+ assert seq_lens.shape[0] == extend_seq_lens.shape[0]
112
+
113
+ # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
114
+ query = query.movedim(0, query.dim() - 2)
115
+
116
+ start_q, start_kv = 0, 0
117
+
118
+ for seq_idx in range(seq_lens.shape[0]):
119
+ # TODO: this loop process a sequence per iter, this is inefficient.
120
+ # Need optimize the performance later.
121
+ extend_seq_len_q = extend_seq_lens[seq_idx]
122
+ prefill_seq_len_q = extend_prefix_lens[seq_idx]
123
+
124
+ seq_len_kv = seq_lens[seq_idx]
125
+ end_q = start_q + extend_seq_len_q
126
+ end_kv = start_kv + seq_len_kv
127
+
128
+ per_req_query = query[:, start_q:end_q, :]
129
+ per_req_query_redundant = torch.empty(
130
+ (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
131
+ dtype=per_req_query.dtype,
132
+ device=per_req_query.device,
133
+ )
134
+
135
+ per_req_query_redundant[:, prefill_seq_len_q:, :] = per_req_query
136
+
137
+ # get key and value from cache. per_req_tokens contains the kv cache
138
+ # index for each token in the sequence.
139
+ req_pool_idx = req_pool_indices[seq_idx]
140
+ per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
141
+ per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
142
+ per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
143
+
144
+ if not causal:
145
+ raise NotImplementedError("Non-causal mode is not yet implemented.")
146
+
147
+ per_req_out_redundant = (
148
+ self.flex_attention(
149
+ per_req_query_redundant.unsqueeze(0),
150
+ per_req_key.unsqueeze(0),
151
+ per_req_value.unsqueeze(0),
152
+ block_mask=self.extend_block_masks[seq_idx],
153
+ scale=scaling,
154
+ enable_gqa=enable_gqa,
155
+ )
156
+ .squeeze(0)
157
+ .movedim(query.dim() - 2, 0)
158
+ )
159
+ output[start_q:end_q, :, :] = per_req_out_redundant[
160
+ prefill_seq_len_q:, :, :
161
+ ]
162
+ start_q, start_kv = end_q, end_kv
163
+ return output
164
+
165
+ def _run_flex_forward_decode(
166
+ self,
167
+ query: torch.Tensor,
168
+ output: torch.Tensor,
169
+ k_cache: torch.Tensor,
170
+ v_cache: torch.Tensor,
171
+ req_to_token: torch.Tensor,
172
+ req_pool_indices: torch.Tensor,
173
+ seq_lens: torch.Tensor,
174
+ scaling=None,
175
+ enable_gqa=False,
176
+ causal=False,
177
+ ):
178
+ """Run the decode forward by using torch flex attention op.
179
+
180
+ Args:
181
+ query: [num_tokens, num_heads, head_size]
182
+ output: [num_tokens, num_heads, head_size]
183
+ k_cache: [max_total_num_tokens, num_heads, head_size]
184
+ v_cache: [max_total_num_tokens, num_heads, head_size]
185
+ req_to_token: [max_num_reqs, max_context_len]
186
+ req_pool_indices: [num_seqs]
187
+ seq_lens: [num_seqs]
188
+ scaling: float or None
189
+ enable_gqa: bool
190
+ causal: bool
191
+
192
+ Returns:
193
+ output: [num_tokens, num_heads, head_size]
194
+ """
195
+
196
+ # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
197
+ query = query.movedim(0, query.dim() - 2)
198
+
199
+ start_q, start_kv = 0, 0
200
+ for seq_idx in range(seq_lens.shape[0]):
201
+ # TODO: this loop process a sequence per iter, this is inefficient.
202
+ # Need optimize the performance later.
203
+
204
+ seq_len_q = 1
205
+ seq_len_kv = seq_lens[seq_idx]
206
+ end_q = start_q + seq_len_q
207
+ end_kv = start_kv + seq_len_kv
208
+
209
+ per_req_query = query[:, start_q:end_q, :]
210
+
211
+ # get key and value from cache. per_req_tokens contains the kv cache
212
+ # index for each token in the sequence.
213
+ req_pool_idx = req_pool_indices[seq_idx]
214
+ per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
215
+ per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
216
+ per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
217
+
218
+ per_req_out = (
219
+ self.flex_attention(
220
+ per_req_query.unsqueeze(0),
221
+ per_req_key.unsqueeze(0),
222
+ per_req_value.unsqueeze(0),
223
+ block_mask=self.decode_block_masks[seq_idx],
224
+ scale=scaling,
225
+ enable_gqa=enable_gqa,
226
+ )
227
+ .squeeze(0)
228
+ .movedim(query.dim() - 2, 0)
229
+ )
230
+
231
+ output[start_q:end_q, :, :] = per_req_out
232
+ start_q, start_kv = end_q, end_kv
233
+
234
+ return output
235
+
236
+ def forward_extend(
237
+ self,
238
+ q,
239
+ k,
240
+ v,
241
+ layer: RadixAttention,
242
+ forward_batch: ForwardBatch,
243
+ save_kv_cache=True,
244
+ ):
245
+ if layer.qk_head_dim != layer.v_head_dim:
246
+ o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
247
+ else:
248
+ o = torch.empty_like(q)
249
+
250
+ if save_kv_cache:
251
+ forward_batch.token_to_kv_pool.set_kv_buffer(
252
+ layer, forward_batch.out_cache_loc, k, v
253
+ )
254
+
255
+ use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
256
+
257
+ q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
258
+ o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
259
+
260
+ causal = True
261
+ if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
262
+ raise NotImplementedError(
263
+ "TorchFlexAttnBackend does not support non-causal attention for now."
264
+ )
265
+
266
+ self._run_flex_forward_extend(
267
+ q_,
268
+ o_,
269
+ forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
270
+ forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
271
+ forward_batch.req_to_token_pool.req_to_token,
272
+ forward_batch.req_pool_indices,
273
+ forward_batch.seq_lens,
274
+ forward_batch.extend_prefix_lens,
275
+ forward_batch.extend_seq_lens,
276
+ scaling=layer.scaling,
277
+ enable_gqa=use_gqa,
278
+ causal=causal,
279
+ )
280
+ return o
281
+
282
+ def forward_decode(
283
+ self,
284
+ q,
285
+ k,
286
+ v,
287
+ layer: RadixAttention,
288
+ forward_batch: ForwardBatch,
289
+ save_kv_cache=True,
290
+ ):
291
+ # During torch.compile, there is a bug in rotary_emb that causes the
292
+ # output value to have a 3D tensor shape. This reshapes the output correctly.
293
+ q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
294
+
295
+ if layer.qk_head_dim != layer.v_head_dim:
296
+ o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
297
+ else:
298
+ o = torch.empty_like(q)
299
+
300
+ if save_kv_cache:
301
+ forward_batch.token_to_kv_pool.set_kv_buffer(
302
+ layer, forward_batch.out_cache_loc, k, v
303
+ )
304
+
305
+ use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
306
+ q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
307
+ o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
308
+
309
+ self._run_flex_forward_decode(
310
+ q_,
311
+ o_,
312
+ forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
313
+ forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
314
+ forward_batch.req_to_token_pool.req_to_token,
315
+ forward_batch.req_pool_indices,
316
+ forward_batch.seq_lens,
317
+ scaling=layer.scaling,
318
+ enable_gqa=use_gqa,
319
+ causal=False,
320
+ )
321
+
322
+ return o
323
+
324
+ def support_triton(self):
325
+ return False
@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
193
193
  else:
194
194
  o = torch.empty_like(q)
195
195
 
196
+ if layer.is_cross_attention:
197
+ cache_loc = forward_batch.encoder_out_cache_loc
198
+ else:
199
+ cache_loc = forward_batch.out_cache_loc
200
+
196
201
  if save_kv_cache:
197
- forward_batch.token_to_kv_pool.set_kv_buffer(
198
- layer, forward_batch.out_cache_loc, k, v
199
- )
202
+ forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
200
203
 
201
204
  use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
202
205
 
@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
241
244
  else:
242
245
  o = torch.empty_like(q)
243
246
 
247
+ if layer.is_cross_attention:
248
+ cache_loc = forward_batch.encoder_out_cache_loc
249
+ else:
250
+ cache_loc = forward_batch.out_cache_loc
251
+
244
252
  if save_kv_cache:
245
- forward_batch.token_to_kv_pool.set_kv_buffer(
246
- layer, forward_batch.out_cache_loc, k, v
247
- )
253
+ forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
248
254
 
249
255
  use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
250
256
 
@@ -12,12 +12,17 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
12
12
  from sglang.srt.layers.dp_attention import get_attention_tp_size
13
13
  from sglang.srt.layers.radix_attention import AttentionType
14
14
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
15
- from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
15
+ from sglang.srt.utils import (
16
+ get_bool_env_var,
17
+ get_device_core_count,
18
+ get_int_env_var,
19
+ next_power_of_2,
20
+ )
16
21
 
17
22
  if TYPE_CHECKING:
18
23
  from sglang.srt.layers.radix_attention import RadixAttention
19
24
  from sglang.srt.model_executor.model_runner import ModelRunner
20
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
25
+ from sglang.srt.speculative.spec_info import SpecInput
21
26
 
22
27
 
23
28
  def logit_capping_mod(logit_capping_method, logit_cap):
@@ -80,7 +85,13 @@ class TritonAttnBackend(AttentionBackend):
80
85
  self.num_kv_head = model_runner.model_config.get_num_kv_heads(
81
86
  get_attention_tp_size()
82
87
  )
83
- self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
88
+ if model_runner.is_hybrid_gdn:
89
+ # For hybrid linear models, layer_id = 0 may not be full attention
90
+ self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
91
+ else:
92
+ self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
93
+ -1
94
+ ]
84
95
  self.max_context_len = model_runner.model_config.context_len
85
96
  self.device = model_runner.device
86
97
  self.device_core_count = get_device_core_count(model_runner.gpu_id)
@@ -89,6 +100,29 @@ class TritonAttnBackend(AttentionBackend):
89
100
  )
90
101
  self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
91
102
 
103
+ # Decide whether enable deterministic inference with batch-invariant operations
104
+ self.enable_deterministic = (
105
+ model_runner.server_args.enable_deterministic_inference
106
+ )
107
+
108
+ # Configure deterministic inference settings
109
+ if self.enable_deterministic:
110
+ # Use fixed split tile size for batch invariance
111
+ self.split_tile_size = get_int_env_var(
112
+ "SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
113
+ )
114
+ # Set static_kv_splits to False to use deterministic logic instead
115
+ self.static_kv_splits = False
116
+ else:
117
+ self.split_tile_size = (
118
+ model_runner.server_args.triton_attention_split_tile_size
119
+ )
120
+
121
+ if self.split_tile_size is not None:
122
+ self.max_kv_splits = (
123
+ self.max_context_len + self.split_tile_size - 1
124
+ ) // self.split_tile_size
125
+
92
126
  # Check arguments
93
127
  assert not (
94
128
  model_runner.sliding_window_size is not None
@@ -143,10 +177,26 @@ class TritonAttnBackend(AttentionBackend):
143
177
  num_group * num_seq == num_token
144
178
  ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
145
179
 
146
- if self.static_kv_splits or self.device_core_count <= 0:
180
+ # Legacy dynamic splitting logic (non-deterministic)
181
+ if (
182
+ self.static_kv_splits or self.device_core_count <= 0
183
+ ) and not self.enable_deterministic:
147
184
  num_kv_splits.fill_(self.max_kv_splits)
148
185
  return
149
186
 
187
+ # deterministic
188
+ if self.split_tile_size is not None and self.enable_deterministic:
189
+ # expand seq_lens to match num_token
190
+ if num_group > 1:
191
+ expanded_seq_lens = seq_lens.repeat_interleave(num_group)
192
+ else:
193
+ expanded_seq_lens = seq_lens
194
+
195
+ num_kv_splits[:] = (
196
+ expanded_seq_lens + self.split_tile_size - 1
197
+ ) // self.split_tile_size
198
+ return
199
+
150
200
  if num_seq < 256:
151
201
  SCHEDULE_SEQ = 256
152
202
  else:
@@ -432,7 +482,7 @@ class TritonAttnBackend(AttentionBackend):
432
482
  seq_lens: torch.Tensor,
433
483
  encoder_lens: Optional[torch.Tensor],
434
484
  forward_mode: ForwardMode,
435
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
485
+ spec_info: Optional[SpecInput],
436
486
  ):
437
487
  assert encoder_lens is None, "Not supported"
438
488
  window_kv_indptr = self.window_kv_indptr
@@ -588,7 +638,7 @@ class TritonAttnBackend(AttentionBackend):
588
638
  seq_lens_sum: int,
589
639
  encoder_lens: Optional[torch.Tensor],
590
640
  forward_mode: ForwardMode,
591
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
641
+ spec_info: Optional[SpecInput],
592
642
  seq_lens_cpu: Optional[torch.Tensor],
593
643
  ):
594
644
  # NOTE: encoder_lens expected to be zeros or None
@@ -833,7 +883,7 @@ class TritonMultiStepDraftBackend:
833
883
  topk: int,
834
884
  speculative_num_steps: int,
835
885
  ):
836
- from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
886
+ from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
837
887
 
838
888
  self.topk = topk
839
889
  self.speculative_num_steps = speculative_num_steps
@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
20
20
  if is_flashinfer_available():
21
21
  import flashinfer
22
22
 
23
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
24
-
25
23
  if TYPE_CHECKING:
26
24
  from sglang.srt.layers.radix_attention import RadixAttention
27
25
  from sglang.srt.model_executor.model_runner import ModelRunner
28
- from sglang.srt.speculative.spec_info import SpecInfo
26
+ from sglang.srt.speculative.spec_info import SpecInput
29
27
 
30
28
  # Constants
31
29
  DEFAULT_WORKSPACE_SIZE_MB = (
@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
201
199
  seq_lens: torch.Tensor,
202
200
  encoder_lens: Optional[torch.Tensor],
203
201
  forward_mode: ForwardMode,
204
- spec_info: Optional[SpecInfo],
202
+ spec_info: Optional[SpecInput],
205
203
  ):
206
204
  """Initialize metadata for CUDA graph capture."""
207
205
  metadata = TRTLLMMHAMetadata()
@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
314
312
  seq_lens_sum: int,
315
313
  encoder_lens: Optional[torch.Tensor],
316
314
  forward_mode: ForwardMode,
317
- spec_info: Optional[SpecInfo],
315
+ spec_info: Optional[SpecInput],
318
316
  seq_lens_cpu: Optional[torch.Tensor],
319
317
  ):
320
318
  """Replay CUDA graph with new inputs."""
@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
661
659
  forward_batch: ForwardBatch,
662
660
  ):
663
661
  assert forward_batch.spec_info is not None
664
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
662
+ assert forward_batch.spec_info.is_draft_input()
665
663
 
666
664
  for i in range(self.speculative_num_steps - 1):
667
665
  self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
678
676
  self, forward_batch: ForwardBatch, bs: int
679
677
  ):
680
678
  assert forward_batch.spec_info is not None
681
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
679
+ assert forward_batch.spec_info.is_draft_input()
682
680
 
683
681
  for i in range(self.speculative_num_steps - 1):
684
682