sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py
@@ -5,21 +5,21 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-import os
 
 import numpy as np
 
@@ -35,6 +35,8 @@ class ForwardMetadata:
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
     seq_lens_list_cumsum: Optional[List[int]] = None
+    seq_lens: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: Optional[torch.Tensor] = None
 
 
 class AscendAttnBackend(AttentionBackend):
@@ -66,6 +68,9 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.q_head_dim = (
+                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+            )
         self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
@@ -101,10 +106,6 @@ class AscendAttnBackend(AttentionBackend):
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
 
         seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
-        if forward_batch.is_extend_in_batch:
-            seq_lens_list_cumsum[-1] = (
-                (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
-            ) * tp_size
         self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
 
         self.graph_mode = False
@@ -126,12 +127,16 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()
 
         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+        metadata.seq_lens = seq_lens
+        metadata.actual_seq_lengths_q = torch.tensor(
+            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+        )
 
         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -146,7 +151,7 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
@@ -160,6 +165,8 @@ class AscendAttnBackend(AttentionBackend):
         metadata.block_tables[:bs, max_seq_pages:].fill_(0)
         metadata.block_tables[bs:, :].fill_(0)
 
+        metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
         self.forward_metadata = metadata
 
         self.graph_mode = True
@@ -167,6 +174,64 @@ class AscendAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
+    def forward_sparse(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: torch.Tensor = None,
+    ):
+
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        if save_kv_cache:
+            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+        q_nope, q_pe = q, q_rope
+        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        block_table = self.forward_metadata.block_tables
+        if is_prefill:
+            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+        else:
+            if self.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_qlen = (
+                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                )
+            else:
+                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
+
+        attn_out = torch.ops.custom.npu_sparse_flash_attention(
+            query=q_nope,
+            key=k_nope,
+            value=k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            sparse_indices=topk_indices,
+            scale_value=layer.scaling,
+            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+            block_table=block_table,
+            sparse_block_size=1,
+            layout_query="TND",
+            layout_kv="PA_BSND",
+            sparse_mode=3,
+        )
+
+        return attn_out
+
     def forward_extend(
         self,
         q,
@@ -175,7 +240,23 @@ class AscendAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
        if not self.use_mla:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -274,6 +355,11 @@ class AscendAttnBackend(AttentionBackend):
            assert (
                layer.qk_head_dim != layer.v_head_dim
            ), "FIA only supports qk_head_dim != v_head_dim"
+           num_token_padding = q.shape[0]
+           q, k, v = [
+               data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]
+           ]
+
            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
 
@@ -293,6 +379,18 @@ class AscendAttnBackend(AttentionBackend):
                next_tokens=0,
            )
 
+           attn_output = attn_output.reshape(-1, layer.tp_q_head_num, layer.v_head_dim)
+           if num_token_padding != forward_batch.num_token_non_padded_cpu:
+               attn_output = torch.cat(
+                   [
+                       attn_output,
+                       attn_output.new_zeros(
+                           num_token_padding - attn_output.shape[0],
+                           *attn_output.shape[1:],
+                       ),
+                   ],
+                   dim=0,
+               )
            return attn_output
 
    def forward_decode_graph(
@@ -401,7 +499,7 @@ class AscendAttnBackend(AttentionBackend):
            antiquant_scale=None,
            sparse_mode=0,
        )
-       output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+       output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
        softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
 
        torch_npu.npu_fused_infer_attention_score.out(
@@ -436,7 +534,24 @@ class AscendAttnBackend(AttentionBackend):
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
+       topk_indices: Optional[torch.Tensor] = None,
    ):
+       if is_mla_preprocess_enabled():
+           # MLAPO does saving kv_cache
+           save_kv_cache = False
+       if topk_indices is not None:
+           return self.forward_sparse(
+               q,
+               k,
+               v,
+               layer,
+               forward_batch,
+               save_kv_cache,
+               q_rope,
+               k_rope,
+               topk_indices,
+           )
+
        if self.graph_mode:
            return self.forward_decode_graph(
                q,
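The extend path above slices q/k/v down to the real (non-padded) tokens before calling the fused NPU kernel, then zero-pads the output back to the original length so downstream ops keep their fixed, graph-friendly shapes. A minimal stand-alone sketch of that pattern in plain PyTorch; `run_unpadded` and `kernel` are illustrative names, not sglang API:

```python
import torch

def run_unpadded(q: torch.Tensor, num_non_padded: int, kernel) -> torch.Tensor:
    # q arrives padded to a fixed token count (e.g. a captured graph shape)
    num_token_padding = q.shape[0]
    out = kernel(q[:num_non_padded])  # compute on the real tokens only
    if num_token_padding != num_non_padded:
        # zero-pad the result back to the padded length
        out = torch.cat(
            [out, out.new_zeros(num_token_padding - out.shape[0], *out.shape[1:])],
            dim=0,
        )
    return out

# e.g. run_unpadded(torch.randn(8, 4), 5, lambda x: x * 2).shape == (8, 4)
```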
sglang/srt/layers/attention/attention_registry.py (new file)
@@ -0,0 +1,226 @@
+import logging
+from typing import TYPE_CHECKING
+
+logger = logging.getLogger(__name__)
+
+
+if TYPE_CHECKING:
+    # evade circular imports
+    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+ATTENTION_BACKENDS = {}
+
+
+def register_attention_backend(name):
+    def decorator(fn):
+        ATTENTION_BACKENDS[name] = fn
+        return fn
+
+    return decorator
+
+
+@register_attention_backend("flashinfer")
+def create_flashinfer_backend(runner):
+    import torch
+
+    if not runner.use_mla_backend:
+        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+
+        # Init streams
+        if runner.server_args.speculative_algorithm == "EAGLE":
+            if (
+                not hasattr(runner, "plan_stream_for_flashinfer")
+                or not runner.plan_stream_for_flashinfer
+            ):
+                runner.plan_stream_for_flashinfer = torch.cuda.Stream()
+        return FlashInferAttnBackend(
+            runner, init_new_workspace=runner.init_new_workspace
+        )
+    else:
+        from sglang.srt.layers.attention.flashinfer_mla_backend import (
+            FlashInferMLAAttnBackend,
+        )
+
+        return FlashInferMLAAttnBackend(runner)
+
+
+@register_attention_backend("trtllm_mla")
+def create_trtllm_mla_backend(runner):
+    if not runner.use_mla_backend:
+        raise ValueError("trtllm_mla backend can only be used with MLA models.")
+    from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+    return TRTLLMMLABackend(runner)
+
+
+@register_attention_backend("aiter")
+def create_aiter_backend(runner):
+    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+    return AiterAttnBackend(runner)
+
+
+@register_attention_backend("wave")
+def create_wave_backend(runner):
+    from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+    return WaveAttnBackend(runner)
+
+
+@register_attention_backend("ascend")
+def create_ascend_backend(runner):
+    from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+    return AscendAttnBackend(runner)
+
+
+@register_attention_backend("nsa")
+def create_nsa_backend(runner):
+    from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+    return NativeSparseAttnBackend(runner)
+
+
+@register_attention_backend("triton")
+def create_triton_backend(runner):
+    assert not runner.model_config.is_encoder_decoder, (
+        "Cross attention is not supported in the triton attention backend. "
+        "Please use `--attention-backend flashinfer`."
+    )
+    if runner.server_args.enable_double_sparsity:
+        from sglang.srt.layers.attention.double_sparsity_backend import (
+            DoubleSparseAttnBackend,
+        )
+
+        return DoubleSparseAttnBackend(runner)
+    else:
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(runner)
+
+
+@register_attention_backend("torch_native")
+def create_torch_native_backend(runner):
+    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+
+    return TorchNativeAttnBackend(runner)
+
+
+@register_attention_backend("flex_attention")
+def create_flex_attention_backend(runner):
+    from sglang.srt.layers.attention.torch_flex_backend import TorchFlexAttnBackend
+
+    return TorchFlexAttnBackend(runner)
+
+
+@register_attention_backend("flashmla")
+def create_flashmla_backend(runner):
+    from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+    return FlashMLABackend(runner)
+
+
+@register_attention_backend("fa3")
+def create_flashattention_v3_backend(runner):
+    import torch
+
+    assert (
+        torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend
+    ) or torch.cuda.get_device_capability()[0] == 9, (
+        "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
+        "Please use `--attention-backend flashinfer`."
+    )
+    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+    return FlashAttentionBackend(runner)
+
+
+@register_attention_backend("fa4")
+def create_flashattention_v4_backend(runner):
+    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+    return FlashAttentionBackend(runner, fa_impl_ver=4)
+
+
+@register_attention_backend("cutlass_mla")
+def create_cutlass_mla_backend(runner):
+    from sglang.srt.layers.attention.cutlass_mla_backend import CutlassMLABackend
+
+    return CutlassMLABackend(runner)
+
+
+@register_attention_backend("trtllm_mha")
+def create_trtllm_mha_backend(runner):
+    if runner.use_mla_backend:
+        raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
+    from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+    return TRTLLMHAAttnBackend(runner)
+
+
+@register_attention_backend("intel_amx")
+def create_intel_amx_backend(runner):
+    from sglang.srt.layers.attention.intel_amx_backend import IntelAMXAttnBackend
+
+    return IntelAMXAttnBackend(runner)
+
+
+@register_attention_backend("dual_chunk_flash_attn")
+def create_dual_chunk_flash_attn_backend(runner):
+    from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+        DualChunkFlashAttentionBackend,
+    )
+
+    return DualChunkFlashAttentionBackend(runner)
+
+
+def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
+    """
+    Wrapper for special models like hybrid GDN, so we don't
+    need to change the code of the original attention backend.
+    """
+    assert not (
+        runner.hybrid_gdn_config is not None and runner.use_mla_backend
+    ), "hybrid_gdn can only be used with non-MLA models."
+
+    if cfg := runner.mambaish_config:
+        from sglang.srt.layers.attention.fla.utils import check_environments
+        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+            GDNAttnBackend,
+            HybridLinearAttnBackend,
+            Mamba2AttnBackend,
+        )
+        from sglang.srt.utils import is_blackwell, is_npu
+
+        check_environments()
+        if runner.hybrid_gdn_config is not None:
+            if is_blackwell():
+                assert (
+                    runner.server_args.attention_backend == "triton"
+                ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
+            if is_npu():
+                assert (
+                    runner.server_args.attention_backend == "ascend"
+                ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
+            logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
+            linear_attn_backend = GDNAttnBackend(runner)
+        elif runner.mamba2_config is not None:
+            linear_attn_backend = Mamba2AttnBackend(runner)
+        else:
+            raise ValueError(
+                "Expected hybrid GDN or NemotronH models, but got unknown model."
+            )
+        full_attn_layers = cfg.full_attention_layer_ids
+        return HybridLinearAttnBackend(
+            full_attn_backend, linear_attn_backend, full_attn_layers
+        )
+
+    return full_attn_backend
+
+
+@register_attention_backend("intel_xpu")
+def create_intel_xpu_backend(runner):
+    from sglang.srt.layers.attention.xpu_backend import XPUAttentionBackend
+
+    return XPUAttentionBackend(runner)
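The new attention_registry.py replaces what would otherwise be a long if/elif chain with a decorator-populated dict: each `create_*` factory registers itself under a backend name, and backend selection becomes a dictionary lookup. A self-contained sketch of the same pattern; the names here are placeholders, not sglang classes:

```python
# Minimal decorator-registry sketch (assumed names: BACKENDS, register, DummyBackend).
BACKENDS = {}

def register(name):
    def decorator(fn):
        BACKENDS[name] = fn  # map backend name -> factory function
        return fn
    return decorator

@register("dummy")
def create_dummy_backend(runner):
    return ("DummyBackend", runner)  # a real factory would return a backend object

# Selection is now a lookup instead of an if/elif chain:
backend = BACKENDS["dummy"](runner="model_runner_placeholder")
```

The design choice keeps each backend's imports lazy (inside its factory), so importing the registry never pulls in hardware-specific dependencies the current machine lacks.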
sglang/srt/layers/attention/base_attn_backend.py
@@ -1,14 +1,15 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInput
 
 
 class AttentionBackend(ABC):
@@ -31,7 +32,7 @@ class AttentionBackend(ABC):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -44,7 +45,7 @@ class AttentionBackend(ABC):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Init the metadata for a forward pass for replaying a cuda graph."""
@@ -54,6 +55,25 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()
 
+    def get_verify_buffers_to_fill_after_draft(self):
+        """
+        Return buffers of verify attention kernels that needs to be filled after draft.
+
+        Typically, these are tree mask and position buffers.
+        """
+        return [None, None]
+
+    def update_verify_buffers_to_fill_after_draft(
+        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
+    ):
+        """
+        Update the buffers returned by get_verify_fill_after_draft_buffers if needed.
+
+        Here, we need to redo the computation of all metadata of the attention backend
+        that depends on tree mask and position buffers.
+        """
+        raise NotImplementedError()
+
     def forward(
         self,
         q: torch.Tensor,
@@ -115,3 +135,11 @@ class AttentionBackend(ABC):
     def support_triton(self):
         """Check if the current backend supports triton."""
         return True
+
+    def get_indexer_metadata(
+        self,
+        layer_id: int,
+        forward_batch: ForwardBatch,
+    ) -> Optional[BaseIndexerMetadata]:
+        """Get the indexer metadata. None means don't support indexer."""
+        return None
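The two new verify-buffer hooks define a small contract for speculative decoding: `get_verify_buffers_to_fill_after_draft` exposes a `[tree_mask, positions]` pair that the draft stage writes into, and `update_verify_buffers_to_fill_after_draft` then recomputes whatever backend metadata depends on them. A hedged sketch of a subclass honoring that contract; the buffer names and sizes are invented for illustration:

```python
from typing import Optional
import torch

class SketchBackend:  # stands in for an AttentionBackend subclass
    def __init__(self, max_tokens: int = 1024):
        # hypothetical persistent buffers the verify kernels read
        self.tree_mask_buf = torch.zeros(max_tokens, dtype=torch.bool)
        self.positions_buf = torch.zeros(max_tokens, dtype=torch.int64)

    def get_verify_buffers_to_fill_after_draft(self):
        # the draft stage fills these in place once drafting finishes
        return [self.tree_mask_buf, self.positions_buf]

    def update_verify_buffers_to_fill_after_draft(self, spec_info, cuda_graph_bs: Optional[int]):
        # recompute metadata derived from the freshly written buffers;
        # a real backend would rebuild its kv indices / attention plan here
        pass
```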
sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 
 _is_cuda = is_cuda()
 if _is_cuda:
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
 
sglang/srt/layers/attention/double_sparsity_backend.py
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         # TODO: Change the hard-coded block_seq_num
         self.BLOCK_SEQ = 128
 
-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
+        if get_global_server_args().triton_attention_reduce_in_fp32:
             self.reduce_dtype = torch.float32
         else:
             self.reduce_dtype = torch.float16
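This hunk is one instance of a release-wide migration from a mutable `global_server_args_dict` to a typed accessor. A generic sketch of the pattern; apart from `get_global_server_args` and the flag name shown in the diff, the names below are illustrative:

```python
from dataclasses import dataclass

@dataclass
class ServerArgs:
    triton_attention_reduce_in_fp32: bool = False

_GLOBAL_ARGS = ServerArgs()  # hypothetical module-level singleton

def get_global_server_args() -> ServerArgs:
    return _GLOBAL_ARGS

# Old style: global_server_args_dict.get("triton_attention_reduce_in_fp32", False)
# New style: attribute access, so typos fail loudly and type checkers can verify it.
flag = get_global_server_args().triton_attention_reduce_in_fp32
```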
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
@@ -1537,7 +1537,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
             query_inter,
             key_cache,
             value_cache,
-            block_table[:, : decode_meta.max_seq_len_inter],
+            block_table,
             decode_meta.seq_lens_inter,
             softmax_scale,
             causal=False,
sglang/srt/layers/attention/fla/chunk.py
@@ -2,7 +2,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 
-import warnings
 from typing import Optional
 
 import torch
sglang/srt/layers/attention/fla/chunk_o.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import triton
sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py
@@ -74,8 +74,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
         (1, 0),
     )
     b_k = tl.load(p_k, boundary_check=(0, 1))
-    b_kb = b_k * b_beta[:, None]
-    b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
+    b_A += tl.dot(b_k, tl.trans(b_k))
 
     if USE_G:
         p_g = tl.make_block_ptr(
@@ -85,6 +84,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
         b_g_diff = b_g[:, None] - b_g[None, :]
         b_A = b_A * safe_exp(b_g_diff)
 
+    b_A *= b_beta[:, None]
     b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
     p_A = tl.make_block_ptr(
         A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
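The reorder keeps the math identical: scaling the rows of K by beta before the K·Kᵀ dot equals scaling the rows of the gated product afterwards, since beta only multiplies the first operand's rows, and doing it once after the dot drops the extra `b_kb` tile and its dtype cast. A quick float64 equivalence check in plain PyTorch, with `exp` standing in for the kernel's clamped `safe_exp`:

```python
import torch

T, K = 8, 16
k = torch.randn(T, K, dtype=torch.float64)
beta = torch.rand(T, dtype=torch.float64)
g = torch.randn(T, dtype=torch.float64)
gate = torch.exp(g[:, None] - g[None, :])  # kernel uses safe_exp (clamped)

old = ((k * beta[:, None]) @ k.T) * gate   # scale K's rows before the dot
new = ((k @ k.T) * gate) * beta[:, None]   # scale the gated product's rows
assert torch.allclose(old, new)
```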
sglang/srt/layers/attention/fla/fused_recurrent.py
@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
     b_g = tl.load(p_g).to(tl.float32)
 
     if USE_QK_L2NORM_IN_KERNEL:
-        b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-        b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+        b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+        b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
     b_q = b_q * scale
     # [BK, BV]
     b_h *= exp(b_g)
@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
     b_g = tl.load(p_g).to(tl.float32)
 
     if USE_QK_L2NORM_IN_KERNEL:
-        b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-        b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+        b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+        b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
     b_q = b_q * scale
     # [BK, BV]
     b_h *= exp(b_g)
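The same one-character fix appears in both kernels here and once more in fused_sigmoid_gating_recurrent.py below: the epsilon moves inside the square root, turning `x / (sqrt(sum(x²)) + eps)` into the standard `x / sqrt(sum(x²) + eps)` L2 normalization. The two placements diverge exactly when the squared norm is comparable to eps, which a few lines of PyTorch show:

```python
import torch

x = torch.full((4,), 1e-4)
eps = 1e-6
old = x / (torch.sqrt((x * x).sum()) + eps)  # result norm ~0.995: near unit length
new = x / torch.sqrt((x * x).sum() + eps)    # result norm ~0.196: tiny inputs damped
print(old.norm().item(), new.norm().item())
# For well-scaled inputs (norm >> sqrt(eps)) the two forms agree numerically.
```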
sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py
@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
 
     # Apply L2 normalization if enabled
     if USE_QK_L2NORM_IN_KERNEL:
-        b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-        b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+        b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+        b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
 
     b_q = b_q * scale
 
sglang/srt/layers/attention/fla/index.py
@@ -3,9 +3,7 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 
 import torch
-import torch.nn.functional as F
 import triton
-import triton.language as tl
 
 from sglang.srt.layers.attention.fla.utils import tensor_cache
 