sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,206 @@
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ ATTENTION_BACKENDS = {}
+
+
+ def register_attention_backend(name):
+ def decorator(fn):
+ ATTENTION_BACKENDS[name] = fn
+ return fn
+
+ return decorator
+
+
+ @register_attention_backend("flashinfer")
+ def create_flashinfer_backend(runner):
+ import torch
+
+ if not runner.use_mla_backend:
+ from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+
+ # Init streams
+ if runner.server_args.speculative_algorithm == "EAGLE":
+ if (
+ not hasattr(runner, "plan_stream_for_flashinfer")
+ or not runner.plan_stream_for_flashinfer
+ ):
+ runner.plan_stream_for_flashinfer = torch.cuda.Stream()
+ return FlashInferAttnBackend(runner)
+ else:
+ from sglang.srt.layers.attention.flashinfer_mla_backend import (
+ FlashInferMLAAttnBackend,
+ )
+
+ return FlashInferMLAAttnBackend(runner)
+
+
+ @register_attention_backend("trtllm_mla")
+ def create_trtllm_mla_backend(runner):
+ if not runner.use_mla_backend:
+ raise ValueError("trtllm_mla backend can only be used with MLA models.")
+ from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+ return TRTLLMMLABackend(runner)
+
+
+ @register_attention_backend("aiter")
+ def create_aiter_backend(runner):
+ from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+ return AiterAttnBackend(runner)
+
+
+ @register_attention_backend("wave")
+ def create_wave_backend(runner):
+ from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+ return WaveAttnBackend(runner)
+
+
+ @register_attention_backend("ascend")
+ def create_ascend_backend(runner):
+ from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+ return AscendAttnBackend(runner)
+
+
+ @register_attention_backend("nsa")
+ def create_nsa_backend(runner):
+ from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+ return NativeSparseAttnBackend(runner)
+
+
+ @register_attention_backend("triton")
+ def create_triton_backend(runner):
+ assert not runner.model_config.is_encoder_decoder, (
+ "Cross attention is not supported in the triton attention backend. "
+ "Please use `--attention-backend flashinfer`."
+ )
+ if runner.server_args.enable_double_sparsity:
+ from sglang.srt.layers.attention.double_sparsity_backend import (
+ DoubleSparseAttnBackend,
+ )
+
+ return DoubleSparseAttnBackend(runner)
+ else:
+ from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+ return TritonAttnBackend(runner)
+
+
+ @register_attention_backend("torch_native")
+ def create_torch_native_backend(runner):
+ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+
+ return TorchNativeAttnBackend(runner)
+
+
+ @register_attention_backend("flex_attention")
+ def create_flex_attention_backend(runner):
+ from sglang.srt.layers.attention.torch_flex_backend import TorchFlexAttnBackend
+
+ return TorchFlexAttnBackend(runner)
+
+
+ @register_attention_backend("flashmla")
+ def create_flashmla_backend(runner):
+ from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+ return FlashMLABackend(runner)
+
+
+ @register_attention_backend("fa3")
+ def create_flashattention_v3_backend(runner):
+ import torch
+
+ assert (
+ torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend
+ ) or torch.cuda.get_device_capability()[0] == 9, (
+ "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
+ "Please use `--attention-backend flashinfer`."
+ )
+ from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+ return FlashAttentionBackend(runner)
+
+
+ @register_attention_backend("fa4")
+ def create_flashattention_v4_backend(runner):
+ assert (
+ runner.use_mla_backend
+ ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
+ from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+ return FlashAttentionBackend(runner, fa_impl_ver=4)
+
+
+ @register_attention_backend("cutlass_mla")
+ def create_cutlass_mla_backend(runner):
+ from sglang.srt.layers.attention.cutlass_mla_backend import CutlassMLABackend
+
+ return CutlassMLABackend(runner)
+
+
+ @register_attention_backend("trtllm_mha")
+ def create_trtllm_mha_backend(runner):
+ if runner.use_mla_backend:
+ raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
+ from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+ return TRTLLMHAAttnBackend(runner)
+
+
+ @register_attention_backend("intel_amx")
+ def create_intel_amx_backend(runner):
+ from sglang.srt.layers.attention.intel_amx_backend import IntelAMXAttnBackend
+
+ return IntelAMXAttnBackend(runner)
+
+
+ @register_attention_backend("dual_chunk_flash_attn")
+ def create_dual_chunk_flash_attn_backend(runner):
+ from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+ DualChunkFlashAttentionBackend,
+ )
+
+ return DualChunkFlashAttentionBackend(runner)
+
+
+ def attn_backend_wrapper(runner, full_attn_backend):
+ """
+ Wrapper for special models like hybrid GDN, so we don't
+ need to change the code of the original attention backend.
+ """
+ assert not (
+ runner.is_hybrid_gdn and runner.use_mla_backend
+ ), "hybrid_gdn can only be used with non-MLA models."
+
+ # wrap for hybrid GDN models
+ if runner.is_hybrid_gdn:
+ from sglang.srt.utils import is_blackwell, is_npu
+
+ if is_blackwell():
+ assert (
+ runner.server_args.attention_backend == "triton"
+ or runner.server_args.attention_backend == "trtllm_mha"
+ ), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
+ if is_npu():
+ assert (
+ runner.server_args.attention_backend == "ascend"
+ ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
+ logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
+ from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+ HybridLinearAttnBackend,
+ MambaAttnBackend,
+ )
+
+ linear_attn_backend = MambaAttnBackend(runner)
+ full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
+ return HybridLinearAttnBackend(
+ full_attn_backend, linear_attn_backend, full_attn_layers
+ )
+
+ return full_attn_backend
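
The new attention_registry module above replaces scattered backend-selection logic with a name-to-factory mapping: each create_* function registers itself in ATTENTION_BACKENDS under a CLI-facing name, and attn_backend_wrapper optionally wraps the result for hybrid GDN models. The short sketch below shows how such a registry is typically consumed; the actual selection code lives in model_runner.py (changed elsewhere in this release and not shown here), so the lookup function and its error handling are illustrative assumptions only.

# Illustrative consumer of the registry above (not taken from this diff).
from sglang.srt.layers.attention.attention_registry import (
    ATTENTION_BACKENDS,
    attn_backend_wrapper,
)

def create_attention_backend(name, runner):
    # Look up the factory registered under `name` and build the backend.
    if name not in ATTENTION_BACKENDS:
        raise ValueError(f"Unknown attention backend: {name}")
    full_attn_backend = ATTENTION_BACKENDS[name](runner)
    # Wrap for hybrid GDN models when needed; otherwise the backend is returned unchanged.
    return attn_backend_wrapper(runner, full_attn_backend)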
@@ -6,9 +6,10 @@ from typing import TYPE_CHECKING, Optional, Union
  import torch
 
  if TYPE_CHECKING:
+ from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+ from sglang.srt.speculative.spec_info import SpecInput
 
 
  class AttentionBackend(ABC):
@@ -31,7 +32,7 @@ class AttentionBackend(ABC):
  seq_lens: torch.Tensor,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ spec_info: Optional[SpecInput],
  ):
  """Init the metadata for a forward pass for capturing a cuda graph."""
  raise NotImplementedError()
@@ -44,7 +45,7 @@
  seq_lens_sum: int,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ spec_info: Optional[SpecInput],
  seq_lens_cpu: Optional[torch.Tensor],
  ):
  """Init the metadata for a forward pass for replaying a cuda graph."""
@@ -115,3 +116,11 @@
  def support_triton(self):
  """Check if the current backend supports triton."""
  return True
+
+ def get_indexer_metadata(
+ self,
+ layer_id: int,
+ forward_batch: ForwardBatch,
+ ) -> Optional[BaseIndexerMetadata]:
+ """Get the indexer metadata. None means don't support indexer."""
+ return None
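
AttentionBackend also gains a get_indexer_metadata hook that returns None by default, signalling that a backend has no NSA indexer support. A backend that does support the indexer would override it along these lines; only the hook signature comes from the diff, the subclass and attribute below are hypothetical.

# Hypothetical override of the new hook (names other than the hook itself are illustrative).
from typing import Optional

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

class IndexerAwareBackend(AttentionBackend):
    # Other required AttentionBackend methods are omitted in this sketch.
    def get_indexer_metadata(
        self, layer_id: int, forward_batch: ForwardBatch
    ) -> Optional[BaseIndexerMetadata]:
        # Metadata prepared earlier (e.g. in init_forward_metadata); None means "no indexer".
        return getattr(self, "indexer_metadata", None)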
@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner
- from sglang.srt.speculative.spec_info import SpecInfo
+ from sglang.srt.speculative.spec_info import SpecInput
 
  _is_cuda = is_cuda()
  if _is_cuda:
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
  seq_lens: torch.Tensor,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[SpecInput],
  ):
  if forward_mode.is_decode_or_idle():
  if spec_info is None:
@@ -190,7 +190,7 @@
  seq_lens_sum: int,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[SpecInfo],
+ spec_info: Optional[SpecInput],
  seq_lens_cpu: Optional[torch.Tensor],
  ):
 
@@ -1537,7 +1537,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
  query_inter,
  key_cache,
  value_cache,
- block_table[:, : decode_meta.max_seq_len_inter],
+ block_table,
  decode_meta.seq_lens_inter,
  softmax_scale,
  causal=False,
@@ -74,8 +74,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
  (1, 0),
  )
  b_k = tl.load(p_k, boundary_check=(0, 1))
- b_kb = b_k * b_beta[:, None]
- b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
+ b_A += tl.dot(b_k, tl.trans(b_k))
 
  if USE_G:
  p_g = tl.make_block_ptr(
@@ -85,6 +84,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
  b_g_diff = b_g[:, None] - b_g[None, :]
  b_A = b_A * safe_exp(b_g_diff)
 
+ b_A *= b_beta[:, None]
  b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
  p_A = tl.make_block_ptr(
  A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
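
In chunk_scaled_dot_kkt_fwd_kernel the beta scaling moves from the K block (before the dot product) to the accumulated result (after the optional gating term). Because beta scales rows and the gate is applied elementwise, the two orderings appear mathematically equivalent; the new ordering also avoids casting the beta-scaled block back to K's low-precision dtype before the matmul. A plain-torch sanity check of that equivalence (float32 reference only, not the Triton kernel):

# Float32 reference check: row-scaling by beta commutes with the elementwise gate,
# so applying beta after the gating term computes the same matrix.
import torch

T, K = 8, 16
k = torch.randn(T, K)
beta = torch.rand(T)
gate = torch.rand(T, T)  # stands in for safe_exp(b_g_diff)

old = ((k * beta[:, None]) @ k.T) * gate   # scale K first, then gate
new = ((k @ k.T) * gate) * beta[:, None]   # gate first, then scale rows by beta
assert torch.allclose(old, new, atol=1e-5)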
@@ -86,8 +86,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
  b_g = tl.load(p_g).to(tl.float32)
 
  if USE_QK_L2NORM_IN_KERNEL:
- b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
- b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+ b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+ b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
  b_q = b_q * scale
  # [BK, BV]
  b_h *= exp(b_g)
@@ -411,8 +411,8 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
  b_g = tl.load(p_g).to(tl.float32)
 
  if USE_QK_L2NORM_IN_KERNEL:
- b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
- b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+ b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+ b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
  b_q = b_q * scale
  # [BK, BV]
  b_h *= exp(b_g)
@@ -119,8 +119,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
 
  # Apply L2 normalization if enabled
  if USE_QK_L2NORM_IN_KERNEL:
- b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
- b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+ b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
+ b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
 
  b_q = b_q * scale
 
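
The two fused_recurrent hunks above and the sigmoid-gating hunk below all make the same change to the in-kernel L2 normalization: the epsilon moves inside the square root, turning x / (||x|| + 1e-6) into x / sqrt(||x||^2 + 1e-6). For ordinary vectors the two forms are nearly identical; for near-zero vectors the new form damps the output instead of rescaling it toward a fixed magnitude. A plain-torch illustration of that behaviour (not the Triton kernels themselves):

# Behavioural comparison of the old and new normalization forms.
import torch

def l2norm_old(x, eps=1e-6):
    return x / (torch.sqrt(torch.sum(x * x)) + eps)

def l2norm_new(x, eps=1e-6):
    return x / torch.sqrt(torch.sum(x * x) + eps)

x = torch.randn(64)
print((l2norm_old(x) - l2norm_new(x)).abs().max())  # negligible for typical inputs

tiny = torch.full((64,), 1e-9)
print(l2norm_old(tiny).norm().item(), l2norm_new(tiny).norm().item())  # ~8e-3 vs ~8e-6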
@@ -11,9 +11,8 @@ import triton.language as tl
  from sglang.srt.configs.model_config import AttentionArch
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.mem_cache.memory_pool import SWAKVPool
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+ from sglang.srt.speculative.spec_info import SpecInput
 
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -305,6 +304,7 @@ class FlashAttentionBackend(AttentionBackend):
  speculative_step_id=0,
  topk=0,
  speculative_num_steps=0,
+ fa_impl_ver=3,
  ):
  super().__init__()
 
@@ -338,6 +338,8 @@
  )
  self.speculative_step_id = speculative_step_id
 
+ self.fa_impl_ver = fa_impl_ver
+
  # Local attention settings
  self.attention_chunk_size = (
  model_runner.attention_chunk_size
@@ -352,6 +354,13 @@
  self.sliding_window_size is not None and self.sliding_window_size > -1
  )
 
+ # If num_splits == 0, we use a heuristic to automatically determine the number of splits.
+ # We set nums splits to 1 if deterministic inference is enabled.
+ # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
+ self.num_splits = (
+ 1 if model_runner.server_args.enable_deterministic_inference else 0
+ )
+
  def init_forward_metadata(self, forward_batch: ForwardBatch):
  """Initialize forward metadata hence all layers in the forward pass can reuse it."""
  metadata = FlashAttentionMetadata()
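
The new num_splits field is threaded into every flash_attn call later in this file: 0 keeps the kernel's auto-split heuristic, while enabling deterministic inference pins it to 1 so the split-KV reduction order cannot change between runs. The underlying reason is that floating-point addition is not associative, so different split counts reduce in different orders and can produce slightly different results; a minimal demonstration of that effect (generic torch, not sglang code):

# Floating-point addition is not associative: two reduction orders over the
# same data generally differ by a small amount, which is why a fixed split
# count is needed for bitwise-reproducible attention outputs.
import torch

torch.manual_seed(0)
x = torch.randn(1_000_000, dtype=torch.float32)

s1 = x.sum()                               # single-pass reduction
s2 = x.view(1000, 1000).sum(dim=0).sum()   # two-stage (split) reduction
print((s1 - s2).abs().item())              # typically a small nonzero difference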
@@ -682,8 +691,7 @@
  k_descale, v_descale = None, None
  # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
  # has corresponding quantization method so that layer.k_scale is not None,
- # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
- if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+ # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
+ # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+ if (
+ self.kv_cache_dtype_str != "auto"
+ and layer.head_dim <= 256
+ and self.fa_impl_ver != 4
+ ):
  if layer.k_scale is not None:
  descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
  k_descale = layer.k_scale.expand(descale_shape)
@@ -712,6 +726,8 @@
 
  # For fa3 interface version compatibility, we put new fields into conditional keyword args
  kwargs = {}
+ if self.fa_impl_ver != 3:
+ kwargs["ver"] = self.fa_impl_ver
  if sinks is not None:
  kwargs["sinks"] = sinks
 
@@ -738,6 +754,7 @@
 
  # Use Flash Attention for prefill
  if not self.use_mla:
+ assert self.fa_impl_ver in [3], "Only FA3 support here"
  # Do multi-head attention
  key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
  layer.layer_id
@@ -770,6 +787,7 @@
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=use_cascade_attn,
+ num_splits=self.num_splits,
  **kwargs,
  )
 
@@ -791,6 +809,7 @@
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=True,
+ num_splits=self.num_splits,
  **kwargs,
  )
  o, _ = merge_state_v2_wrapper(
@@ -830,6 +849,7 @@
  softmax_scale=layer.scaling,
  causal=False,
  return_softmax_lse=True,
+ **kwargs,
  )
  else:
  # MHA for extend part of sequence without attending prefix kv cache
@@ -844,6 +864,7 @@
  softmax_scale=layer.scaling,
  causal=True,
  return_softmax_lse=forward_batch.mha_return_lse,
+ **kwargs,
  )
  if forward_batch.mha_return_lse:
  output, lse, *rest = output
@@ -851,6 +872,7 @@
  return output, lse
  return output
  else:
+ assert self.fa_impl_ver in [3], "Only FA3 support here"
  # Do absorbed multi-latent attention
  kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
  layer.layer_id
@@ -892,6 +914,7 @@
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=use_cascade_attn,
+ num_splits=self.num_splits,
  )
  if use_cascade_attn:
  o, softmax_lse, *rest = result
@@ -913,6 +936,7 @@
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=True,
+ num_splits=self.num_splits,
  )
  )
  o, _ = merge_state_v2_wrapper(
@@ -939,6 +963,7 @@
  k_rope: Optional[torch.Tensor] = None,
  sinks: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
+ assert self.fa_impl_ver in [3], "Only FA3 support decoding"
  if k is not None:
  assert v is not None
  if save_kv_cache:
@@ -985,6 +1010,8 @@
 
  # For fa3 interface version compatibility, we put new fields into conditional keyword args
  kwargs = {}
+ if self.fa_impl_ver != 3:
+ kwargs["ver"] = self.fa_impl_ver
  if sinks is not None:
  kwargs["sinks"] = sinks
 
@@ -1030,6 +1057,7 @@
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ num_splits=self.num_splits,
  **kwargs,
  )
  elif use_local_attn:
@@ -1049,6 +1077,7 @@
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ num_splits=self.num_splits,
  **kwargs,
  )
  else:
@@ -1077,6 +1106,7 @@
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=use_cascade_attn,
+ num_splits=self.num_splits,
  **kwargs,
  )
  if use_cascade_attn:
@@ -1098,6 +1128,7 @@
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=True,
+ num_splits=self.num_splits,
  **kwargs,
  )
  )
@@ -1153,6 +1184,7 @@
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
+ num_splits=self.num_splits,
  )
  if use_cascade_attn:
  o, softmax_lse, *rest = result
@@ -1173,6 +1205,7 @@
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=True,
+ num_splits=self.num_splits,
  )
  o, _ = merge_state_v2(
  o,
@@ -1453,7 +1486,7 @@
  seq_lens: torch.Tensor,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ spec_info: Optional[SpecInput],
  ):
  """Initialize forward metadata for capturing CUDA graph."""
  metadata = FlashAttentionMetadata()
@@ -1688,7 +1721,7 @@
  seq_lens_sum: int,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ spec_info: Optional[SpecInput],
  seq_lens_cpu: Optional[torch.Tensor],
  out_cache_loc: Optional[torch.Tensor] = None,
  ):
@@ -2306,7 +2339,7 @@ class FlashAttentionMultiStepBackend:
  forward_batch: ForwardBatch,
  ):
  assert forward_batch.spec_info is not None
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
+ assert forward_batch.spec_info.is_draft_input()
 
  for i in range(self.speculative_num_steps - 1):
  self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -2323,7 +2356,7 @@
  self, forward_batch: ForwardBatch, bs: int
  ):
  assert forward_batch.spec_info is not None
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
+ assert forward_batch.spec_info.is_draft_input()
 
  for i in range(self.speculative_num_steps - 1):
  # TODO: incrementally update the metadata for the later steps,
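
Throughout these backends, isinstance checks against EagleDraftInput/EagleVerifyInput are replaced by the SpecInput interface from sglang.srt.speculative.spec_info (a new module in this release), queried through methods such as is_draft_input(). The real class is not shown in this diff; the sketch below is purely for orientation and may differ from the actual implementation.

# Illustrative sketch only; the actual SpecInput in
# sglang/srt/speculative/spec_info.py may be structured differently.
from abc import ABC
from enum import Enum, auto

class SpecInputType(Enum):
    DRAFT = auto()
    VERIFY = auto()

class SpecInput(ABC):
    def __init__(self, spec_input_type: SpecInputType):
        self.spec_input_type = spec_input_type

    def is_draft_input(self) -> bool:
        # Backends call this instead of isinstance(spec_info, EagleDraftInput).
        return self.spec_input_type == SpecInputType.DRAFT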