sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/nsa_backend.py
@@ -0,0 +1,887 @@
+ from __future__ import annotations
+
+ import sys
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
+
+ import torch
+
+ from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+ from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
+ from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
+ from sglang.srt.layers.attention.nsa.transform_index import (
+     transform_index_page_table_decode,
+     transform_index_page_table_prefill,
+ )
+ from sglang.srt.layers.attention.nsa.utils import (
+     NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+     NSA_FUSE_TOPK,
+     compute_nsa_seqlens,
+ )
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.utils import is_hip
+
+ # from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.radix_attention import RadixAttention
+     from sglang.srt.model_executor.model_runner import ModelRunner
+     from sglang.srt.speculative.spec_info import SpecInput
+
+ _is_hip = is_hip()
+
+ if _is_hip:
+     try:
+         from aiter import (
+             flash_attn_varlen_func,
+             mha_batch_prefill_func,
+             paged_attention_ragged,
+         )
+         from aiter.mla import mla_decode_fwd, mla_prefill_fwd
+     except ImportError:
+         print(
+             "aiter is an AMD-specific kernel library. Please make sure aiter is installed on your AMD device."
+         )
+ else:
+     from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+
+
+ @dataclass(frozen=True)
+ class NSAFlashMLAMetadata:
+     """Metadata only needed by FlashMLA"""
+
+     flashmla_metadata: torch.Tensor
+     num_splits: torch.Tensor
+
+     def slice(self, sli):
+         return NSAFlashMLAMetadata(
+             flashmla_metadata=self.flashmla_metadata,
+             num_splits=self.num_splits[sli],
+         )
+
+     def copy_(self, other: "NSAFlashMLAMetadata"):
+         self.flashmla_metadata.copy_(other.flashmla_metadata)
+         self.num_splits.copy_(other.num_splits)
+
+
+ @dataclass(frozen=True)
+ class NSAMetadata:
+     page_size: int
+
+     # Sequence lengths for the forward batch
+     cache_seqlens_int32: torch.Tensor
+     # Maximum sequence length for query
+     max_seq_len_q: int
+     # Maximum sequence length for key
+     max_seq_len_k: int
+     # Cumulative sequence lengths for query
+     cu_seqlens_q: torch.Tensor
+     # Cumulative sequence lengths for key
+     cu_seqlens_k: torch.Tensor
+     # Page table, the index of KV cache tables/blocks;
+     # this table always uses page_size = 1
+     page_table_1: torch.Tensor
+
+     # NOTE(dark): This property will be used in:
+     # 1. dense decode/prefill, where paged flash attention needs real_page_table
+     # 2. sparse decode/prefill, where the indexer needs real_page_table to compute the score
+     real_page_table: torch.Tensor
+
+     # NSA metadata (for NSA prefill, these are expanded per query token)
+     nsa_cache_seqlens_int32: torch.Tensor  # these seqlens are clipped to `topk`
+     nsa_cu_seqlens_q: torch.Tensor  # must be arange(0, len(nsa_cu_seqlens_k))
+     nsa_cu_seqlens_k: torch.Tensor  # cumsum of `nsa_cache_seqlens_int32`
+     nsa_extend_seq_lens_list: List[int]
+     nsa_seqlens_expanded: torch.Tensor  # expanded, unclipped `seqlens`
+     nsa_max_seqlen_q: Literal[1] = 1  # always 1 for decode, variable for extend
+
+     flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
+
+
+ @dataclass(frozen=True)
+ class NSAIndexerMetadata(BaseIndexerMetadata):
+     attn_metadata: NSAMetadata
+
+     def get_seqlens_int32(self) -> torch.Tensor:
+         return self.attn_metadata.cache_seqlens_int32
+
+     def get_page_table_64(self) -> torch.Tensor:
+         return self.attn_metadata.real_page_table
+
+     def get_seqlens_expanded(self) -> torch.Tensor:
+         return self.attn_metadata.nsa_seqlens_expanded
+
+     def topk_transform(
+         self,
+         logits: torch.Tensor,
+         topk: int,
+     ) -> torch.Tensor:
+         from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
+
+         if not NSA_FUSE_TOPK:
+             return fast_topk_v2(logits, self.get_seqlens_expanded(), topk)
+
+         # NOTE(dark): if fused, we return a transformed page table directly
+         return fast_topk_transform_fused(
+             score=logits,
+             lengths=self.get_seqlens_expanded(),
+             page_table_size_1=self.attn_metadata.page_table_1,
+             cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
+             topk=topk,
+         )
+
+
+ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
+     assert seqlens.dtype == torch.int32 and seqlens.is_cuda
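+     # e.g. seqlens = [3, 5, 2] -> cu_seqlens = [0, 3, 8, 10]; the pad prepends
+     # the leading zero, so sequence i occupies the range [cu[i], cu[i+1]).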
+     return torch.nn.functional.pad(
+         torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
+     )
+
+
+ _NSA_IMPL_T: TypeAlias = Literal[
+     "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
+ ]
+
+ NSA_PREFILL_IMPL: _NSA_IMPL_T
+ NSA_DECODE_IMPL: _NSA_IMPL_T
+
+
+ class NativeSparseAttnBackend(AttentionBackend):
+     def __init__(self, model_runner: ModelRunner):
+         super().__init__()
+         self.forward_metadata: NSAMetadata
+         self.device = model_runner.device
+         assert isinstance(model_runner.page_size, int)
+         self.real_page_size = model_runner.page_size
+         self.num_splits = (
+             1 if model_runner.server_args.enable_deterministic_inference else 0
+         )
+         self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
+         assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
+         self.nsa_kv_cache_store_fp8 = (
+             model_runner.token_to_kv_pool.nsa_kv_cache_store_fp8
+         )
+         self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
+         self.max_context_len = model_runner.model_config.context_len
+         self.num_q_heads = (
+             model_runner.model_config.num_attention_heads // get_attention_tp_size()
+         )
+         self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
+
+         assert model_runner.req_to_token_pool is not None
+         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+         global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
+         NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
+         NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
+
+         self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
+
+         if _is_hip:
+             max_bs = model_runner.req_to_token_pool.size
+
+             self.kv_indptr = torch.zeros(
+                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+             )
+
+     def get_device_int32_arange(self, l: int) -> torch.Tensor:
+         if l > len(self._arange_buf):
+             next_pow_of_2 = 1 << (l - 1).bit_length()
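+             # Grow to the next power of two so the buffer reallocates
+             # geometrically rather than on every slightly larger request
+             # (e.g. l = 16385 -> a 32768-element buffer).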
+             self._arange_buf = torch.arange(
+                 next_pow_of_2, device=self.device, dtype=torch.int32
+             )
+         return self._arange_buf[:l]
+
+     def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
+         page_size = self.real_page_size
+         if page_size == 1:
+             return page_table
+         max_seqlen_k = page_table.shape[1]
+         strided_indices = torch.arange(
+             0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
+         )
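+         # page_table holds per-token slot indices (page_size = 1); sampling every
+         # page_size-th column and dividing by page_size yields the real page
+         # indices (e.g. with page_size = 64, slots 0..63 map to page 0).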
+         return page_table[:, strided_indices] // page_size
+
+     def init_forward_metadata(self, forward_batch: ForwardBatch):
+         """Init the metadata for a forward pass."""
+         batch_size = forward_batch.batch_size
+         device = forward_batch.seq_lens.device
+
+         assert (
+             forward_batch.spec_info is None
+         ), "Speculative decoding is not yet supported by the NSA backend"
+         cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+         cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+         assert forward_batch.seq_lens_cpu is not None
+         max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
+         page_table = forward_batch.req_to_token_pool.req_to_token[
+             forward_batch.req_pool_indices, :max_seqlen_k
+         ]
+
+         if forward_batch.forward_mode.is_decode_or_idle():
+             extend_seq_lens_cpu = [1] * batch_size
+             max_seqlen_q = 1
+             cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
+             seqlens_expanded = cache_seqlens_int32
+         elif forward_batch.forward_mode.is_extend():
+             assert (
+                 forward_batch.extend_seq_lens_cpu is not None
+                 and forward_batch.extend_seq_lens is not None
+                 and forward_batch.extend_prefix_lens_cpu is not None
+             ), "None of these may be None"
+             extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
+             assert forward_batch.extend_seq_lens is not None
+             if any(forward_batch.extend_prefix_lens_cpu):
+                 max_seqlen_q = max(extend_seq_lens_cpu)
+                 cu_seqlens_q = compute_cu_seqlens(
+                     forward_batch.extend_seq_lens.to(torch.int32)
+                 )
+             else:
+                 max_seqlen_q = max_seqlen_k
+                 cu_seqlens_q = cu_seqlens_k
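+             # One entry per query token: the causal KV length visible to it
+             # (e.g. qo_len = 3, kv_len = 10 -> [8, 9, 10]).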
+             seqlens_expanded = torch.cat(
+                 [
+                     torch.arange(
+                         kv_len - qo_len + 1,
+                         kv_len + 1,
+                         dtype=torch.int32,
+                         device=device,
+                     )
+                     for qo_len, kv_len in zip(
+                         forward_batch.extend_seq_lens_cpu,
+                         forward_batch.seq_lens_cpu.tolist(),
+                         strict=True,
+                     )
+                 ]
+             )
+         else:
+             assert False, f"Unsupported {forward_batch.forward_mode = }"
+
+         # 1D, expanded seqlens (1D means cheap to compute, so always compute it)
+         nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+             original_seq_lens=seqlens_expanded,
+             nsa_index_topk=self.nsa_index_topk,
+         )
+         nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
+         nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
+
+         metadata = NSAMetadata(
+             page_size=self.real_page_size,
+             cache_seqlens_int32=cache_seqlens_int32,
+             max_seq_len_q=max_seqlen_q,
+             max_seq_len_k=max_seqlen_k,
+             cu_seqlens_q=cu_seqlens_q,
+             cu_seqlens_k=cu_seqlens_k,
+             page_table_1=page_table,
+             flashmla_metadata=(
+                 self._compute_flashmla_metadata(
+                     cache_seqlens=nsa_cache_seqlens_int32,
+                     seq_len_q=1,  # TODO handle MTP which is not 1
+                 )
+                 if NSA_DECODE_IMPL == "flashmla_decode"
+                 else None
+             ),
+             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
+             nsa_cu_seqlens_q=nsa_cu_seqlens_q,
+             nsa_cu_seqlens_k=nsa_cu_seqlens_k,
+             nsa_seqlens_expanded=seqlens_expanded,
+             nsa_extend_seq_lens_list=extend_seq_lens_cpu,
+             real_page_table=self._transform_table_1_to_real(page_table),
+         )
+
+         self.forward_metadata = metadata
+
+     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+         """Initialize CUDA graph state for the attention backend.
+
+         Args:
+             max_bs (int): Maximum batch size to support in CUDA graphs
+
+         This creates fixed-size tensors that will be reused during CUDA graph replay
+         to avoid memory allocations.
+         """
+         self.decode_cuda_graph_metadata: Dict = {
+             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+             "cu_seqlens_q": torch.arange(
+                 0, max_bs + 1, dtype=torch.int32, device=self.device
+             ),
+             "cu_seqlens_k": torch.zeros(
+                 max_bs + 1, dtype=torch.int32, device=self.device
+             ),
+             # fake page_table for sparse_prefill
+             "page_table": torch.zeros(
+                 max_bs,
+                 self.max_context_len,
+                 dtype=torch.int32,
+                 device=self.device,
+             ),
+             "flashmla_metadata": (
+                 self._compute_flashmla_metadata(
+                     cache_seqlens=torch.ones(
+                         max_bs, dtype=torch.int32, device=self.device
+                     ),
+                     seq_len_q=1,  # TODO handle MTP which is not 1
+                 )
+                 if NSA_DECODE_IMPL == "flashmla_decode"
+                 else None
+             ),
+         }
+
+     def init_forward_metadata_capture_cuda_graph(
+         self,
+         bs: int,
+         num_tokens: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         encoder_lens: Optional[torch.Tensor],
+         forward_mode: ForwardMode,
+         spec_info: Optional[SpecInput],
+     ):
+         """Initialize forward metadata for capturing CUDA graph."""
+         assert forward_mode.is_decode_or_idle(), "Only decode is supported for now"
+         assert (
+             spec_info is None
+         ), "Speculative decoding is not yet supported by the NSA backend"
+
+         # Normal Decode
+         # Get sequence information
+         cache_seqlens_int32 = seq_lens.to(torch.int32)
+         cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+
+         # Use max context length for seq_len_k
+         page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+         max_seq_len_k = page_table_1.shape[1]
+
+         # Precompute page table
+         # Precompute cumulative sequence lengths
+
+         # NOTE(dark): this is always arange, since we are decoding
+         cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
+         nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+             cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
+         )
+         nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
+         nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
+         real_page_table = self._transform_table_1_to_real(page_table_1)
+
+         if NSA_DECODE_IMPL == "flashmla_decode":
+             flashmla_metadata = self.decode_cuda_graph_metadata[
+                 "flashmla_metadata"
+             ].slice(slice(0, bs + 1))
+             flashmla_metadata.copy_(
+                 self._compute_flashmla_metadata(
+                     cache_seqlens=nsa_cache_seqlens_int32,
+                     seq_len_q=1,  # TODO handle MTP which is not 1
+                 )
+             )
+         else:
+             flashmla_metadata = None
+
+         metadata = NSAMetadata(
+             page_size=self.real_page_size,
+             cache_seqlens_int32=cache_seqlens_int32,
+             max_seq_len_q=1,
+             max_seq_len_k=max_seq_len_k,
+             cu_seqlens_q=cu_seqlens_q,
+             cu_seqlens_k=cu_seqlens_k,
+             page_table_1=page_table_1,
+             flashmla_metadata=flashmla_metadata,
+             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
+             nsa_cu_seqlens_q=nsa_cu_seqlens_q,
+             nsa_cu_seqlens_k=nsa_cu_seqlens_k,
+             nsa_seqlens_expanded=cache_seqlens_int32,
+             real_page_table=real_page_table,
+             nsa_extend_seq_lens_list=[1] * bs,
+         )
+         self.decode_cuda_graph_metadata[bs] = metadata
+         self.forward_metadata = metadata
+
+     def init_forward_metadata_replay_cuda_graph(
+         self,
+         bs: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         seq_lens_sum: int,
+         encoder_lens: Optional[torch.Tensor],
+         forward_mode: ForwardMode,
+         spec_info: Optional[SpecInput],
+         seq_lens_cpu: Optional[torch.Tensor],
+         out_cache_loc: Optional[torch.Tensor] = None,
+     ):
+         """Initialize forward metadata for replaying CUDA graph."""
+         assert seq_lens_cpu is not None
+         assert forward_mode.is_decode_or_idle(), "Only decode is supported for now"
+         assert (
+             spec_info is None
+         ), "Speculative decoding is not yet supported by the NSA backend"
+         seq_lens = seq_lens[:bs]
+         seq_lens_cpu = seq_lens_cpu[:bs]
+         req_pool_indices = req_pool_indices[:bs]
+
+         # Normal Decode
+         metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
+         max_len = int(seq_lens_cpu.max().item())
+
+         cache_seqlens = seq_lens.to(torch.int32)
+         metadata.cache_seqlens_int32.copy_(cache_seqlens)
+         metadata.cu_seqlens_k[1:].copy_(
+             torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+         )
+         page_indices = self.req_to_token[req_pool_indices, :max_len]
+         metadata.page_table_1[:, :max_len].copy_(page_indices)
+         assert (
+             metadata.nsa_cache_seqlens_int32 is not None
+             and metadata.nsa_cu_seqlens_k is not None
+             and self.nsa_index_topk is not None
+         )
+         nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
+         metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
+         metadata.nsa_cu_seqlens_k[1:].copy_(
+             torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
+         )
+         # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
+
+         assert self.real_page_size == metadata.page_size
+         if self.real_page_size > 1:
+             real_table = self._transform_table_1_to_real(page_indices)
+             new_len = real_table.shape[1]
+             metadata.real_page_table[:, :new_len].copy_(real_table)
+         else:
+             assert metadata.real_page_table is metadata.page_table_1
+
+         if NSA_DECODE_IMPL == "flashmla_decode":
+             metadata.flashmla_metadata.copy_(
+                 self._compute_flashmla_metadata(
+                     cache_seqlens=nsa_cache_seqlens,
+                     seq_len_q=1,  # TODO handle MTP which is not 1
+                 )
+             )
+
+         self.forward_metadata = metadata
+
+     def forward_extend(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache=True,
+         # For multi-head latent attention
+         q_rope: Optional[torch.Tensor] = None,
+         k_rope: Optional[torch.Tensor] = None,
+         topk_indices: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         assert (
+             not forward_batch.forward_mode.is_target_verify()
+             and not forward_batch.forward_mode.is_draft_extend()
+         ), "NSA backend doesn't support speculative decoding"
+         if k is not None:
+             assert v is not None
+             if save_kv_cache:
+                 cache_loc = (
+                     forward_batch.out_cache_loc
+                     if not layer.is_cross_attention
+                     else forward_batch.encoder_out_cache_loc
+                 )
+                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
+                     layer,
+                     cache_loc,
+                     k,
+                     k_rope,
+                 )
+
+         metadata = self.forward_metadata
+         causal = not layer.is_cross_attention
+         assert causal, "NSA is causal only"
+
+         # For fa3 interface version compatibility, we put new fields into conditional keyword args
+         kwargs = {}
+
+         # Do absorbed multi-latent attention
+         assert q_rope is not None
+         kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+         # when stored in fp8 and computed in fp8, there is no need to convert the dtype
+         if not (
+             NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and self.nsa_kv_cache_store_fp8
+         ):
+             kv_cache = kv_cache.to(q.dtype)
+
+         if q_rope is not None:
+             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+             q_rope = q_rope.view(
+                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+             )
+         else:
+             q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+             q_nope = q_all[:, :, : layer.v_head_dim]
+             q_rope = q_all[:, :, layer.v_head_dim :]
+
+         # NOTE(dark): here, we use page size = 1
+
+         if NSA_FUSE_TOPK:
+             page_table_1 = topk_indices
+         else:
+             assert metadata.nsa_extend_seq_lens_list is not None
+             page_table_1 = transform_index_page_table_prefill(
+                 page_table=metadata.page_table_1,
+                 topk_indices=topk_indices,
+                 extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
+                 page_size=1,
+             )
+         if NSA_PREFILL_IMPL == "tilelang":
+             if q_rope is not None:
+                 q_all = torch.cat([q_nope, q_rope], dim=-1)
+             return self._forward_tilelang(
+                 q_all=q_all,
+                 kv_cache=kv_cache,
+                 page_table_1=page_table_1,
+                 sm_scale=layer.scaling,
+                 v_head_dim=layer.v_head_dim,
+             )
+         elif NSA_PREFILL_IMPL == "flashmla_prefill":
+             if q_rope is not None:
+                 q_all = torch.cat([q_nope, q_rope], dim=-1)
+             return self._forward_flashmla_prefill(
+                 q_all=q_all,
+                 kv_cache=kv_cache,
+                 page_table_1=page_table_1,
+                 sm_scale=layer.scaling,
+                 v_head_dim=layer.v_head_dim,
+             )
+         elif NSA_PREFILL_IMPL == "flashmla_decode":
+             if q_rope is not None:
+                 q_all = torch.cat([q_nope, q_rope], dim=-1)
+             return self._forward_flashmla_decode(
+                 q_all=q_all,
+                 kv_cache=kv_cache,
+                 sm_scale=layer.scaling,
+                 v_head_dim=layer.v_head_dim,
+                 # TODO optimize args
+                 layer=layer,
+                 metadata=metadata,
+                 page_table_1=page_table_1,
+             )
+         elif NSA_PREFILL_IMPL == "fa3":
+             return self._forward_fa3(
+                 q_rope=q_rope,
+                 kv_cache=kv_cache,
+                 v_head_dim=layer.v_head_dim,
+                 q_nope=q_nope,
+                 page_table=page_table_1,
+                 cache_seqlens=metadata.nsa_cache_seqlens_int32,
+                 cu_seqlens_q=metadata.nsa_cu_seqlens_q,
+                 cu_seqlens_k=metadata.nsa_cu_seqlens_k,
+                 max_seqlen_q=metadata.nsa_max_seqlen_q,
+                 sm_scale=layer.scaling,
+                 logit_cap=layer.logit_cap,
+                 page_size=1,
+             )
+         else:
+             raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")
+
+     def forward_decode(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache=True,
+         # For multi-head latent attention
+         q_rope: Optional[torch.Tensor] = None,
+         k_rope: Optional[torch.Tensor] = None,
+         topk_indices: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if k is not None:
+             assert v is not None
+             if save_kv_cache:
+                 cache_loc = (
+                     forward_batch.out_cache_loc
+                     if not layer.is_cross_attention
+                     else forward_batch.encoder_out_cache_loc
+                 )
+                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
+                     layer,
+                     cache_loc,
+                     k,
+                     k_rope,
+                 )
+
+         metadata = self.forward_metadata
+         causal = not layer.is_cross_attention
+         assert causal, "NSA is causal only"
+
+         # Do absorbed multi-latent attention
+         kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+         if q_rope is not None:
+             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+             q_rope = q_rope.view(
+                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+             )
+         else:
+             q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+             q_nope = q_all[:, :, : layer.v_head_dim]
+             q_rope = q_all[:, :, layer.v_head_dim :]
+
+         if NSA_FUSE_TOPK:
+             page_table_1 = topk_indices
+         else:
+             page_table_1 = transform_index_page_table_decode(
+                 page_table=metadata.page_table_1,
+                 topk_indices=topk_indices,
+                 page_size=1,
+             )
+
+         if NSA_DECODE_IMPL == "flashmla_prefill":
+             if q_rope is not None:
+                 q_all = torch.cat([q_nope, q_rope], dim=-1)
+             return self._forward_flashmla_prefill(
+                 q_all=q_all,
+                 kv_cache=kv_cache,
+                 page_table_1=page_table_1,
+                 sm_scale=layer.scaling,
+                 v_head_dim=layer.v_head_dim,
+             )
+         elif NSA_DECODE_IMPL == "flashmla_decode":
+             if q_rope is not None:
+                 q_all = torch.cat([q_nope, q_rope], dim=-1)
+             return self._forward_flashmla_decode(
+                 q_all=q_all,
+                 kv_cache=kv_cache,
+                 sm_scale=layer.scaling,
+                 v_head_dim=layer.v_head_dim,
+                 # TODO optimize args
+                 layer=layer,
+                 metadata=metadata,
+                 page_table_1=page_table_1,
+             )
+         elif NSA_DECODE_IMPL == "tilelang":
+             if q_rope is not None:
+                 q_all = torch.cat([q_nope, q_rope], dim=-1)
+             return self._forward_tilelang(
+                 q_all=q_all,
+                 kv_cache=kv_cache,
+                 page_table_1=page_table_1,
+                 sm_scale=layer.scaling,
+                 v_head_dim=layer.v_head_dim,
+             )
+         elif NSA_DECODE_IMPL == "fa3":
+             return self._forward_fa3(
+                 q_rope=q_rope,
+                 kv_cache=kv_cache,
+                 v_head_dim=layer.v_head_dim,
+                 q_nope=q_nope,
+                 page_table=page_table_1,
+                 cache_seqlens=metadata.nsa_cache_seqlens_int32,
+                 cu_seqlens_q=metadata.nsa_cu_seqlens_q,
+                 cu_seqlens_k=metadata.nsa_cu_seqlens_k,
+                 max_seqlen_q=metadata.nsa_max_seqlen_q,
+                 sm_scale=layer.scaling,
+                 logit_cap=layer.logit_cap,
+                 page_size=1,
+             )
+         elif NSA_DECODE_IMPL == "aiter":
+             if q_rope is not None:
+                 q_all = torch.cat([q_nope, q_rope], dim=-1)
+             return self._forward_aiter(
+                 q_all=q_all,
+                 kv_cache=kv_cache,
+                 page_table_1=page_table_1,
+                 layer=layer,
+                 metadata=metadata,
+                 bs=forward_batch.batch_size,
+             )
+
+         else:
+             assert False, f"Unsupported {NSA_DECODE_IMPL = }"
+
+     def _forward_fa3(
+         self,
+         q_rope: torch.Tensor,
+         kv_cache: torch.Tensor,
+         v_head_dim: int,
+         q_nope: torch.Tensor,
+         page_table: torch.Tensor,
+         cache_seqlens: torch.Tensor,
+         cu_seqlens_q: torch.Tensor,
+         cu_seqlens_k: torch.Tensor,
+         max_seqlen_q: int,
+         sm_scale: float,
+         logit_cap: float,
+         page_size: int,
+     ) -> torch.Tensor:
+         k_rope_cache = kv_cache[:, :, v_head_dim:]
+         c_kv_cache = kv_cache[:, :, :v_head_dim]
+         qk_rope_dim = k_rope_cache.shape[-1]
+         k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
+         c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
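+         # Absorbed MLA through the FA3 kvcache interface (a reading of the call
+         # below, not documented here): q_rope scores against the cached rope
+         # keys, while `qv` supplies the latent (nope) query, so the compressed
+         # KV cache contributes both the remaining score term and the values.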
+         o = flash_attn_with_kvcache(
+             q=q_rope,
+             k_cache=k_rope_cache,
+             v_cache=c_kv_cache,
+             qv=q_nope,
+             page_table=page_table,
+             cache_seqlens=cache_seqlens,
+             cu_seqlens_q=cu_seqlens_q,
+             cu_seqlens_k_new=cu_seqlens_k,
+             max_seqlen_q=max_seqlen_q,
+             softmax_scale=sm_scale,
+             causal=True,
+             softcap=logit_cap,
+             return_softmax_lse=False,
+             num_splits=self.num_splits,
+         )
+         return o  # type: ignore
+
+     def _forward_flashmla_prefill(
+         self,
+         q_all: torch.Tensor,
+         kv_cache: torch.Tensor,
+         v_head_dim: int,
+         page_table_1: torch.Tensor,
+         sm_scale: float,
+     ) -> torch.Tensor:
+         from flash_mla import flash_mla_sparse_fwd
+
+         o, _, _ = flash_mla_sparse_fwd(
+             q=q_all,
+             kv=kv_cache,
+             indices=page_table_1.unsqueeze(1),
+             sm_scale=sm_scale,
+             d_v=v_head_dim,
+         )
+         return o
+
+     def _forward_flashmla_decode(
+         self,
+         q_all: torch.Tensor,
+         kv_cache: torch.Tensor,
+         v_head_dim: int,
+         sm_scale: float,
+         layer,
+         metadata: NSAMetadata,
+         page_table_1,
+     ) -> torch.Tensor:
+         from flash_mla import flash_mla_with_kvcache
+
+         cache_seqlens = metadata.nsa_cache_seqlens_int32
+
+         # TODO the 2nd dim is seq_len_q, which needs to be > 1 for MTP
+         q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
+         kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
+         assert self.real_page_size == 64, "only page size 64 is supported"
+
+         if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not self.nsa_kv_cache_store_fp8:
+             # inefficiently quantize the whole cache
+             kv_cache = quantize_k_cache(kv_cache)
+
+         indices = page_table_1.unsqueeze(1)
+         assert (
+             indices.shape[-1] == self.nsa_index_topk
+         )  # requirement of the FlashMLA decode kernel
+
+         o, _ = flash_mla_with_kvcache(
+             q=q_all,
+             k_cache=kv_cache,
+             cache_seqlens=cache_seqlens,
+             head_dim_v=v_head_dim,
+             tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
+             num_splits=metadata.flashmla_metadata.num_splits,
+             softmax_scale=sm_scale,
+             indices=indices,
+             # the doc says this is unused, but passing None raises an error
+             block_table=torch.empty(
+                 (q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
+             ),
+             is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+         )
+         return o
+
+     def _forward_tilelang(
+         self,
+         q_all: torch.Tensor,
+         kv_cache: torch.Tensor,
+         v_head_dim: int,
+         page_table_1: torch.Tensor,
+         sm_scale: float,
+     ) -> torch.Tensor:
+         from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd
+
+         return tilelang_sparse_fwd(
+             q=q_all,
+             kv=kv_cache,
+             indices=page_table_1.unsqueeze(1),
+             sm_scale=sm_scale,
+             d_v=v_head_dim,
+         )
+
+     def _forward_aiter(
+         self,
+         q_all: torch.Tensor,
+         kv_cache: torch.Tensor,
+         page_table_1: torch.Tensor,
+         layer: RadixAttention,
+         metadata: NSAMetadata,
+         bs: int,
+     ) -> torch.Tensor:
+         q = q_all.reshape(-1, layer.tp_q_head_num * layer.head_dim)
+
+         if layer.head_dim != layer.v_head_dim:
+             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+         else:
+             o = torch.empty_like(q)
+
+         kv_indptr = self.kv_indptr
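+
+         # Rows of the transformed page table shorter than topk are padded with
+         # -1; drop the padding and flatten the rest into a CSR-style
+         # (kv_indptr, kv_indices) pair for the aiter decode kernel.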
+         non_minus1_mask = page_table_1 != -1
+         non_minus1_counts = non_minus1_mask.sum(dim=1)
+         kv_indptr[1 : bs + 1] = torch.cumsum(non_minus1_counts, dim=0)
+
+         kv_indices = page_table_1[page_table_1 != -1]
+
+         mla_decode_fwd(
+             q.view(-1, layer.tp_q_head_num, layer.head_dim),
+             kv_cache.view(-1, 1, 1, layer.head_dim),
+             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+             metadata.cu_seqlens_q,
+             kv_indptr,
+             kv_indices,
+             metadata.cu_seqlens_q,
+             metadata.max_seq_len_q,
+             layer.scaling,
+             layer.logit_cap,
+         )
+         # kv_cache = kv_cache.view(-1, 1, layer.head_dim)
+         return o
+
+     def get_cuda_graph_seq_len_fill_value(self):
+         """Get the fill value for sequence length in CUDA graph."""
+         return 1
+
+     def get_indexer_metadata(
+         self, layer_id: int, forward_batch: ForwardBatch
+     ) -> NSAIndexerMetadata:
+         return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
+
+     def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
+         from flash_mla import get_mla_metadata
+
+         flashmla_metadata, num_splits = get_mla_metadata(
+             cache_seqlens=cache_seqlens,
+             # TODO the doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`,
+             # but the parameter name suggests it wants seq_len_q?
+             num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
+             num_heads_k=1,
+             num_heads_q=self.num_q_heads,
+             is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+             topk=self.nsa_index_topk,
+         )
+
+         return NSAFlashMLAMetadata(
+             flashmla_metadata=flashmla_metadata,
+             num_splits=num_splits,
+         )