sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py
@@ -5,13 +5,15 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var
 
 if TYPE_CHECKING:
@@ -33,6 +35,9 @@ class ForwardMetadata:
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
+    seq_lens_list_cumsum: Optional[List[int]] = None
+    seq_lens: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: Optional[torch.Tensor] = None
 
 
 class AscendAttnBackend(AttentionBackend):
@@ -64,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.q_head_dim = (
+                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+            )
             self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
@@ -83,6 +91,7 @@ class AscendAttnBackend(AttentionBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()
 
         self.forward_metadata.block_tables = (
@@ -96,9 +105,9 @@
                 forward_batch.extend_seq_lens.cpu().int()
             )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-        self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
-            forward_batch.extend_seq_lens_cpu
-        )
+
+        seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+        self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
 
         self.graph_mode = False
 
@@ -119,12 +128,16 @@
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()
 
         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+        metadata.seq_lens = seq_lens
+        metadata.actual_seq_lengths_q = torch.tensor(
+            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+        )
 
         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -139,7 +152,7 @@
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
@@ -153,6 +166,8 @@
         metadata.block_tables[:bs, max_seq_pages:].fill_(0)
         metadata.block_tables[bs:, :].fill_(0)
 
+        metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
         self.forward_metadata = metadata
 
         self.graph_mode = True
@@ -160,6 +175,64 @@
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
+    def forward_sparse(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: torch.Tensor = None,
+    ):
+
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        if save_kv_cache:
+            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+        q_nope, q_pe = q, q_rope
+        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        block_table = self.forward_metadata.block_tables
+        if is_prefill:
+            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+        else:
+            if self.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_qlen = (
+                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                )
+            else:
+                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
+
+        attn_out = torch.ops.custom.npu_sparse_flash_attention(
+            query=q_nope,
+            key=k_nope,
+            value=k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            sparse_indices=topk_indices,
+            scale_value=layer.scaling,
+            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+            block_table=block_table,
+            sparse_block_size=1,
+            layout_query="TND",
+            layout_kv="PA_BSND",
+            sparse_mode=3,
+        )
+
+        return attn_out
+
     def forward_extend(
         self,
         q,
@@ -168,7 +241,23 @@
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -368,7 +457,7 @@
             -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
         )
 
-        q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+        q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
        q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
         if self.forward_metadata.seq_lens_cpu_int is None:
             actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
@@ -394,7 +483,7 @@
             antiquant_scale=None,
             sparse_mode=0,
         )
-        output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+        output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
         softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
 
         torch_npu.npu_fused_infer_attention_score.out(
@@ -429,7 +518,24 @@
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO does saving kv_cache
+            save_kv_cache = False
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
+
        if self.graph_mode:
            return self.forward_decode_graph(
                q,
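In the forward_sparse method added above, the prefill path derives actual_seq_qlen by a cumulative sum over per-request sequence lengths, turning lengths into the exclusive end offsets the packed TND ("token, num_heads, dim") layout expects, while the decode path falls back to a plain arange because each request contributes exactly one query token. A minimal standalone sketch of that conversion, with made-up lengths (illustrative only, not part of the diff):

import torch

# Hypothetical per-request sequence lengths in one packed batch.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Cumulative sum turns lengths into exclusive end offsets along the
# flattened token dimension: request i owns tokens [cumsum[i-1], cumsum[i]).
actual_seq_qlen = torch.cumsum(seq_lens, dim=0)
print(actual_seq_qlen.tolist())  # [3, 8, 10]

# Decode contributes one new token per request, so the offsets reduce to
# an arange, matching the fallback branch in forward_sparse.
decode_qlen = torch.arange(1, seq_lens.numel() + 1, dtype=torch.int32)
print(decode_qlen.tolist())  # [1, 2, 3]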
sglang/srt/layers/attention/attention_registry.py (new file)
@@ -0,0 +1,215 @@
+import logging
+from typing import TYPE_CHECKING
+
+logger = logging.getLogger(__name__)
+
+
+if TYPE_CHECKING:
+    # evade circular imports
+    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+ATTENTION_BACKENDS = {}
+
+
+def register_attention_backend(name):
+    def decorator(fn):
+        ATTENTION_BACKENDS[name] = fn
+        return fn
+
+    return decorator
+
+
+@register_attention_backend("flashinfer")
+def create_flashinfer_backend(runner):
+    import torch
+
+    if not runner.use_mla_backend:
+        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+
+        # Init streams
+        if runner.server_args.speculative_algorithm == "EAGLE":
+            if (
+                not hasattr(runner, "plan_stream_for_flashinfer")
+                or not runner.plan_stream_for_flashinfer
+            ):
+                runner.plan_stream_for_flashinfer = torch.cuda.Stream()
+        return FlashInferAttnBackend(runner)
+    else:
+        from sglang.srt.layers.attention.flashinfer_mla_backend import (
+            FlashInferMLAAttnBackend,
+        )
+
+        return FlashInferMLAAttnBackend(runner)
+
+
+@register_attention_backend("trtllm_mla")
+def create_trtllm_mla_backend(runner):
+    if not runner.use_mla_backend:
+        raise ValueError("trtllm_mla backend can only be used with MLA models.")
+    from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+    return TRTLLMMLABackend(runner)
+
+
+@register_attention_backend("aiter")
+def create_aiter_backend(runner):
+    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+    return AiterAttnBackend(runner)
+
+
+@register_attention_backend("wave")
+def create_wave_backend(runner):
+    from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+    return WaveAttnBackend(runner)
+
+
+@register_attention_backend("ascend")
+def create_ascend_backend(runner):
+    from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+    return AscendAttnBackend(runner)
+
+
+@register_attention_backend("nsa")
+def create_nsa_backend(runner):
+    from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+    return NativeSparseAttnBackend(runner)
+
+
+@register_attention_backend("triton")
+def create_triton_backend(runner):
+    assert not runner.model_config.is_encoder_decoder, (
+        "Cross attention is not supported in the triton attention backend. "
+        "Please use `--attention-backend flashinfer`."
+    )
+    if runner.server_args.enable_double_sparsity:
+        from sglang.srt.layers.attention.double_sparsity_backend import (
+            DoubleSparseAttnBackend,
+        )
+
+        return DoubleSparseAttnBackend(runner)
+    else:
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(runner)
+
+
+@register_attention_backend("torch_native")
+def create_torch_native_backend(runner):
+    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+
+    return TorchNativeAttnBackend(runner)
+
+
+@register_attention_backend("flex_attention")
+def create_flex_attention_backend(runner):
+    from sglang.srt.layers.attention.torch_flex_backend import TorchFlexAttnBackend
+
+    return TorchFlexAttnBackend(runner)
+
+
+@register_attention_backend("flashmla")
+def create_flashmla_backend(runner):
+    from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+    return FlashMLABackend(runner)
+
+
+@register_attention_backend("fa3")
+def create_flashattention_v3_backend(runner):
+    import torch
+
+    assert (
+        torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend
+    ) or torch.cuda.get_device_capability()[0] == 9, (
+        "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
+        "Please use `--attention-backend flashinfer`."
+    )
+    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+    return FlashAttentionBackend(runner)
+
+
+@register_attention_backend("fa4")
+def create_flashattention_v4_backend(runner):
+    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+    return FlashAttentionBackend(runner, fa_impl_ver=4)
+
+
+@register_attention_backend("cutlass_mla")
+def create_cutlass_mla_backend(runner):
+    from sglang.srt.layers.attention.cutlass_mla_backend import CutlassMLABackend
+
+    return CutlassMLABackend(runner)
+
+
+@register_attention_backend("trtllm_mha")
+def create_trtllm_mha_backend(runner):
+    if runner.use_mla_backend:
+        raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
+    from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+    return TRTLLMHAAttnBackend(runner)
+
+
+@register_attention_backend("intel_amx")
+def create_intel_amx_backend(runner):
+    from sglang.srt.layers.attention.intel_amx_backend import IntelAMXAttnBackend
+
+    return IntelAMXAttnBackend(runner)
+
+
+@register_attention_backend("dual_chunk_flash_attn")
+def create_dual_chunk_flash_attn_backend(runner):
+    from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+        DualChunkFlashAttentionBackend,
+    )
+
+    return DualChunkFlashAttentionBackend(runner)
+
+
+def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
+    """
+    Wrapper for special models like hybrid GDN, so we don't
+    need to change the code of the original attention backend.
+    """
+    assert not (
+        runner.hybrid_gdn_config is not None and runner.use_mla_backend
+    ), "hybrid_gdn can only be used with non-MLA models."
+
+    if cfg := runner.mambaish_config:
+        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+            GDNAttnBackend,
+            HybridLinearAttnBackend,
+            Mamba2AttnBackend,
+        )
+        from sglang.srt.utils import is_blackwell, is_npu
+
+        if runner.hybrid_gdn_config is not None:
+            if is_blackwell():
+                assert (
+                    runner.server_args.attention_backend == "triton"
+                ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
+            if is_npu():
+                assert (
+                    runner.server_args.attention_backend == "ascend"
+                ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
+            logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
+            linear_attn_backend = GDNAttnBackend(runner)
+        elif runner.mamba2_config is not None:
+            linear_attn_backend = Mamba2AttnBackend(runner)
+        else:
+            raise ValueError(
+                "Expected hybrid GDN or NemotronH models, but got unknown model."
+            )
+        full_attn_layers = cfg.full_attention_layer_ids
+        return HybridLinearAttnBackend(
+            full_attn_backend, linear_attn_backend, full_attn_layers
+        )
+
+    return full_attn_backend
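The new registry module replaces hard-coded per-backend dispatch with a name-to-factory table that the register_attention_backend decorator fills in at import time. A minimal sketch of the lookup side, assuming a ModelRunner instance named runner is in hand; the helper create_attn_backend is illustrative and not part of the module:

from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS

def create_attn_backend(name: str, runner):
    # Illustrative lookup: resolve a backend name to a backend instance.
    try:
        factory = ATTENTION_BACKENDS[name]
    except KeyError:
        raise ValueError(f"Unknown attention backend: {name!r}") from None
    # Every registered factory takes the runner and returns a backend object.
    return factory(runner)

# e.g. backend = create_attn_backend("triton", runner)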
sglang/srt/layers/attention/base_attn_backend.py
@@ -6,9 +6,10 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInput
 
 
 class AttentionBackend(ABC):
@@ -31,7 +32,7 @@ class AttentionBackend(ABC):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -44,7 +45,7 @@
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Init the metadata for a forward pass for replaying a cuda graph."""
@@ -115,3 +116,11 @@
     def support_triton(self):
         """Check if the current backend supports triton."""
         return True
+
+    def get_indexer_metadata(
+        self,
+        layer_id: int,
+        forward_batch: ForwardBatch,
+    ) -> Optional[BaseIndexerMetadata]:
+        """Get the indexer metadata. None means don't support indexer."""
+        return None
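get_indexer_metadata gives indexer-capable backends (such as the new nsa backend in this release) a hook that every other backend inherits as a no-op; returning None signals that the backend has no indexer. A toy stand-in showing the contract only; the names below are illustrative, not the real NSA implementation:

from typing import Any, Optional

class ToyIndexerBackend:
    # Illustrative shape of an indexer-capable backend; a real one derives
    # from AttentionBackend and builds BaseIndexerMetadata per forward batch.

    def __init__(self) -> None:
        # Hypothetical per-batch state, populated during metadata init
        # in a real implementation.
        self._indexer_metadata: Optional[Any] = None

    def get_indexer_metadata(self, layer_id: int, forward_batch: Any) -> Optional[Any]:
        # Non-None signals indexer support; None keeps the base-class default.
        return self._indexer_metadata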
sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
 
 _is_cuda = is_cuda()
 if _is_cuda:
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -190,7 +190,7 @@
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
 
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
@@ -1537,7 +1537,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
             query_inter,
             key_cache,
             value_cache,
-            block_table[:, : decode_meta.max_seq_len_inter],
+            block_table,
             decode_meta.seq_lens_inter,
             softmax_scale,
             causal=False,