sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/fla/utils.py (new file)
@@ -0,0 +1,331 @@
+ # Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
+ # -*- coding: utf-8 -*-
+
+ import contextlib
+ import functools
+ import logging
+ import os
+ import sys
+ from enum import Enum
+ from functools import lru_cache
+ from typing import Any, Callable, Dict, Literal, Optional, Tuple
+
+ import torch
+ import triton
+ from packaging import version
+
+ logger = logging.getLogger(__name__)
+
+ COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
+ FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
+
+
+ @lru_cache(maxsize=1)
+ def check_environments():
+     """
+     Checks the current operating system, Triton version, and Python version,
+     issuing warnings if they don't meet recommendations.
+     This function's body only runs once due to lru_cache.
+     """
+     # Check Operating System
+     if sys.platform == "win32":
+         logger.warning(
+             "Detected Windows operating system. Triton does not have an official Windows release, "
+             "thus FLA will not be adapted for Windows, and any potential errors will not be fixed. "
+             "Please consider using a Linux environment for compatibility."
+         )
+
+     triton_version = version.parse(triton.__version__)
+     required_triton_version = version.parse("3.2.0")
+
+     if triton_version < required_triton_version:
+         logger.warning(
+             f"Current Triton version {triton_version} is below the recommended 3.2.0 version. "
+             "Errors may occur and these issues will not be fixed. "
+             "Please consider upgrading Triton."
+         )
+
+     # Check Python version
+     py_version = version.parse(f"{sys.version_info.major}.{sys.version_info.minor}")
+     required_py_version = version.parse("3.11")
+
+     if py_version < required_py_version:
+         logger.warning(
+             f"Current Python version {py_version} is below the recommended 3.11 version. "
+             "It is recommended to upgrade to Python 3.11 or higher for the best experience."
+         )
+
+     return None
+
+
+ check_environments()
+
+
+ def get_abs_err(x, y):
+     return (x.detach() - y.detach()).flatten().abs().max().item()
+
+
+ def get_err_ratio(x, y):
+     err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
+     base = (x.detach()).flatten().square().mean().sqrt().item()
+     return err / (base + 1e-8)
+
+
+ def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
+     abs_atol = get_abs_err(ref, tri)
+     msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
+     logger.info(msg)
+     error_rate = get_err_ratio(ref, tri)
+     if abs_atol <= err_atol:
+         return
+     if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
+         if error_rate > ratio:
+             import warnings
+
+             warnings.warn(msg)
+     else:
+         assert error_rate < ratio, msg
+
+
+ SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
+
+
+ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+     """
+     A decorator that caches the most recent results of a function with tensor inputs.
+     This decorator will store the output of the decorated function for the most recent set of input tensors.
+     The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
+     Args:
+         fn (Callable[..., torch.Tensor]):
+             The function to be decorated. It should take tensor inputs and return tensor outputs.
+     Returns:
+         Callable[..., torch.Tensor]:
+             A wrapped version of the input function with single-entry caching.
+     """
+
+     cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = []
+     cache_size = 4
+
+     @functools.wraps(fn)
+     def wrapper(*args: Any, **kwargs: Any) -> Any:
+         nonlocal cache_entries, cache_size
+         for i, entry in enumerate(cache_entries):
+             last_args, last_kwargs, last_result = entry
+             if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
+                 if all(a is b for a, b in zip(args, last_args)) and all(
+                     k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
+                 ):
+                     cache_entries = (
+                         cache_entries[:i]
+                         + cache_entries[i + 1 :]
+                         + [(args, kwargs, last_result)]
+                     )
+                     return last_result
+
+         result = fn(*args, **kwargs)
+
+         if len(cache_entries) >= cache_size:
+             cache_entries = cache_entries[1:]
+         cache_entries.append((args, kwargs, result))
+         return result
+
+     return wrapper
+
+
+ def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+     """
+     A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+     """
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         contiguous_args = (
+             i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
+         )
+         contiguous_kwargs = {
+             k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+             for k, v in kwargs.items()
+         }
+
+         tensor = None
+         for arg in args:
+             if isinstance(arg, torch.Tensor):
+                 tensor = arg
+                 break
+         if tensor is None:
+             for value in kwargs.values():
+                 if isinstance(value, torch.Tensor):
+                     tensor = value
+                     break
+
+         if tensor is not None:
+             ctx = custom_device_ctx(tensor.device.index)
+         else:
+             ctx = contextlib.nullcontext()
+
+         with ctx:
+             return fn(*contiguous_args, **contiguous_kwargs)
+
+     return wrapper
+
+
+ contiguous = input_guard
+
+
+ def require_version(version, hint):
+     """
+     Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
+     """
+
+     def decorator(fn):
+         @functools.wraps(fn)
+         def wrapper(ctx, *args, **kwargs):
+             from transformers.utils.versions import require_version
+
+             require_version(version, hint)
+             return fn(
+                 ctx,
+                 *(
+                     i if not isinstance(i, torch.Tensor) else i.contiguous()
+                     for i in args
+                 ),
+                 **{
+                     k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+                     for k, v in kwargs.items()
+                 },
+             )
+
+         return wrapper
+
+     return decorator
+
+
+ def checkpoint(fn):
+     def wrapper(*args, **kwargs):
+         return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
+
+     return wrapper
+
+
+ @lru_cache(maxsize=None)
+ def check_pytorch_version(version_s: str = "2.4") -> bool:
+     return version.parse(torch.__version__) >= version.parse(version_s)
+
+
+ def _cpu_device_warning():
+     import warnings
+
+     warnings.warn(
+         ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
+     )
+
+
+ @lru_cache(maxsize=None)
+ def get_multiprocessor_count(tensor_idx: int = 0) -> int:
+     try:
+         return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
+             "multiprocessor_count"
+         ]
+     except BaseException:
+         _cpu_device_warning()
+         return -1
+
+
+ @lru_cache(maxsize=None)
+ def get_available_device() -> str:
+     try:
+         return triton.runtime.driver.active.get_current_target().backend
+     except BaseException:
+         _cpu_device_warning()
+         return "cpu"
+
+
+ @lru_cache(maxsize=None)
+ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+     device = get_available_device()
+     if device == "cuda":
+         return "nvidia"
+     elif device == "hip":
+         return "amd"
+     elif device == "xpu":
+         return "intel"
+     else:
+         return device
+
+
+ # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+ # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+ # Therefore, we need to check the triton backend to determine the actual GPU vendor.
+ device = get_available_device() if get_available_device() != "hip" else "cuda"
+ device_torch_lib = getattr(torch, device)
+ device_platform = _check_platform()
+
+ is_amd = device_platform == "amd"
+ is_intel = device_platform == "intel"
+ is_nvidia = device_platform == "nvidia"
+ is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
+ is_nvidia_hopper = is_nvidia and (
+     "NVIDIA H" in torch.cuda.get_device_name(0)
+     or torch.cuda.get_device_capability()[0] >= 9
+ )
+ use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+
+ # Nvidia Ampere or newer, haven't check AMD and intel yet.
+ is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
+ is_gather_supported = hasattr(triton.language, "gather")
+
+
+ def get_all_max_shared_mem():
+     try:
+         return [
+             triton.runtime.driver.active.utils.get_device_properties(i)[
+                 "max_shared_mem"
+             ]
+             for i in range(device_torch_lib.device_count())
+         ]
+     except BaseException:
+         _cpu_device_warning()
+         return [-1]
+
+
+ class Backend(Enum):
+     ADA = 101376 # RTX 4090
+     AMPERE = 166912 # A100
+     HOPPER = 232448 # H100
+     DEFAULT = 102400 # Default
+
+     @classmethod
+     def get_shared_memory(cls, arch: str) -> int:
+         try:
+             return cls[arch.upper()].value
+         except KeyError:
+             return cls.DEFAULT.value
+
+
+ @lru_cache(maxsize=None)
+ def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
+     try:
+         device_shared_mem_list = get_all_max_shared_mem()
+         max_shared_memory = device_shared_mem_list[tensor_idx]
+         return max_shared_memory >= Backend.get_shared_memory(arch)
+     except Exception:
+         return False
+
+
+ if check_pytorch_version("2.4"):
+     device = "cuda" if device == "cpu" else device
+     autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
+     autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
+
+     def custom_device_ctx(index: int):
+         return device_torch_lib.device(index)
+
+ else:
+     assert (
+         device == "cuda"
+     ), "Only cuda device is supported for PyTorch version < 2.4.0."
+     autocast_custom_fwd = device_torch_lib.amp.custom_fwd
+     autocast_custom_bwd = device_torch_lib.amp.custom_bwd
+
+     def custom_device_ctx(index: int):
+         return torch.cuda.device(index)
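For orientation, a minimal usage sketch of the tensor_cache and input_guard decorators defined above; the toy helper functions and sizes are hypothetical, and a CUDA build of PyTorch with Triton is assumed:

import torch

from sglang.srt.layers.attention.fla.utils import input_guard, tensor_cache


@tensor_cache
def seqlens_to_offsets(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: recomputed only when called with new tensor objects;
    # repeated calls with the same objects hit the small 4-entry cache.
    return cu_seqlens[1:] - cu_seqlens[:-1]


@input_guard
def double(x: torch.Tensor) -> torch.Tensor:
    # input_guard hands the function a contiguous tensor and activates its device.
    return x * 2


cu = torch.tensor([0, 3, 7], device="cuda")
print(seqlens_to_offsets(cu))   # computed
print(seqlens_to_offsets(cu))   # served from the cache (same tensor object)
print(double(torch.ones(8, device="cuda")[::2]))  # non-contiguous view is made contiguous first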
sglang/srt/layers/attention/fla/wy_fast.py (new file)
@@ -0,0 +1,158 @@
+ # Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
+ from sglang.srt.layers.attention.fla.op import safe_exp
+ from sglang.srt.layers.attention.fla.utils import check_shared_mem
+
+
+ @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
+ # @triton.autotune(
+ #     configs=[
+ #         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+ #         for num_warps in [2, 4, 8]
+ #         for num_stages in [2, 3, 4]
+ #     ],
+ #     key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
+ # )
+ @triton.jit(do_not_specialize=["T"])
+ def recompute_w_u_fwd_kernel(
+     k,
+     v,
+     beta,
+     w,
+     u,
+     A,
+     g,
+     cu_seqlens,
+     chunk_indices,
+     T,
+     H: tl.constexpr,
+     Hg: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     IS_VARLEN: tl.constexpr,
+ ):
+     i_t, i_bh = tl.program_id(0), tl.program_id(1)
+     i_b, i_h = i_bh // H, i_bh % H
+     if IS_VARLEN:
+         i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
+             chunk_indices + i_t * 2 + 1
+         ).to(tl.int32)
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
+             cu_seqlens + i_n + 1
+         ).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+     p_beta = tl.make_block_ptr(
+         beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
+     )
+     p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
+     p_A = tl.make_block_ptr(
+         A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
+     )
+     b_beta = tl.load(p_beta, boundary_check=(0,))
+     b_A = tl.load(p_A, boundary_check=(0, 1))
+     b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
+
+     for i_v in range(tl.cdiv(V, BV)):
+         p_v = tl.make_block_ptr(
+             v + (bos * H + i_h) * V,
+             (T, V),
+             (H * V, 1),
+             (i_t * BT, i_v * BV),
+             (BT, BV),
+             (1, 0),
+         )
+         p_u = tl.make_block_ptr(
+             u + (bos * H + i_h) * V,
+             (T, V),
+             (H * V, 1),
+             (i_t * BT, i_v * BV),
+             (BT, BV),
+             (1, 0),
+         )
+         b_v = tl.load(p_v, boundary_check=(0, 1))
+         b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
+         b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+         tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+     for i_k in range(tl.cdiv(K, BK)):
+         p_k = tl.make_block_ptr(
+             k + (bos * Hg + i_h // (H // Hg)) * K,
+             (T, K),
+             (Hg * K, 1),
+             (i_t * BT, i_k * BK),
+             (BT, BK),
+             (1, 0),
+         )
+         p_w = tl.make_block_ptr(
+             w + (bos * H + i_h) * K,
+             (T, K),
+             (H * K, 1),
+             (i_t * BT, i_k * BK),
+             (BT, BK),
+             (1, 0),
+         )
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
+         b_w = tl.dot(b_A, b_kb)
+         tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def recompute_w_u_fwd(
+     k: torch.Tensor,
+     v: torch.Tensor,
+     beta: torch.Tensor,
+     g_cumsum: torch.Tensor,
+     A: torch.Tensor,
+     cu_seqlens: Optional[torch.LongTensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     B, T, Hg, K, V = *k.shape, v.shape[-1]
+     H = v.shape[-2]
+     BT = A.shape[-1]
+
+     chunk_indices = (
+         prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+     )
+     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+     BK = 64
+     BV = 64
+     u = torch.empty_like(v)
+     w = k.new_empty(B, T, H, K)
+     recompute_w_u_fwd_kernel[(NT, B * H)](
+         k=k,
+         v=v,
+         beta=beta,
+         w=w,
+         u=u,
+         A=A,
+         g=g_cumsum,
+         cu_seqlens=cu_seqlens,
+         chunk_indices=chunk_indices,
+         T=T,
+         H=H,
+         Hg=Hg,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         num_warps=4,
+         num_stages=3,
+     )
+     return w, u
+
+
+ fwd_recompute_w_u = recompute_w_u_fwd
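As a rough sketch of the shape contract recompute_w_u_fwd above expects (read off the wrapper: k is [B, T, Hg, K], v is [B, T, H, V], beta and g_cumsum are [B, T, H], A is [B, T, H, BT]); the sizes below are made up and a CUDA device with Triton is assumed:

import torch

from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd

B, T, H, Hg, K, V, BT = 1, 128, 8, 8, 64, 128, 64  # illustrative sizes only
k = torch.randn(B, T, Hg, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
beta = torch.randn(B, T, H, device="cuda", dtype=torch.bfloat16)
g_cumsum = torch.randn(B, T, H, device="cuda", dtype=torch.float32)
A = torch.randn(B, T, H, BT, device="cuda", dtype=torch.bfloat16)  # per-chunk BT x BT blocks

w, u = recompute_w_u_fwd(k, v, beta, g_cumsum, A, cu_seqlens=None)
print(w.shape, u.shape)  # w: [B, T, H, K], u: same shape as v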
sglang/srt/layers/attention/flashattention_backend.py
@@ -11,9 +11,8 @@ import triton.language as tl
  from sglang.srt.configs.model_config import AttentionArch
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.mem_cache.memory_pool import SWAKVPool
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+ from sglang.srt.speculative.spec_info import SpecInput

  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -305,6 +304,7 @@ class FlashAttentionBackend(AttentionBackend):
  speculative_step_id=0,
  topk=0,
  speculative_num_steps=0,
+ fa_impl_ver=3,
  ):
  super().__init__()

@@ -338,6 +338,8 @@ class FlashAttentionBackend(AttentionBackend):
  )
  self.speculative_step_id = speculative_step_id

+ self.fa_impl_ver = fa_impl_ver
+
  # Local attention settings
  self.attention_chunk_size = (
  model_runner.attention_chunk_size
@@ -352,6 +354,13 @@ class FlashAttentionBackend(AttentionBackend):
  self.sliding_window_size is not None and self.sliding_window_size > -1
  )

+ # If num_splits == 0, we use a heuristic to automatically determine the number of splits.
+ # We set nums splits to 1 if deterministic inference is enabled.
+ # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
+ self.num_splits = (
+ 1 if model_runner.server_args.enable_deterministic_inference else 0
+ )
+
  def init_forward_metadata(self, forward_batch: ForwardBatch):
  """Initialize forward metadata hence all layers in the forward pass can reuse it."""
  metadata = FlashAttentionMetadata()
@@ -682,8 +691,13 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale, v_descale = None, None
  # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
  # has corresponding quantization method so that layer.k_scale is not None,
- # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
- if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+ # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
+ # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+ if (
+ self.kv_cache_dtype_str != "auto"
+ and layer.head_dim <= 256
+ and self.fa_impl_ver != 4
+ ):
  if layer.k_scale is not None:
  descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
  k_descale = layer.k_scale.expand(descale_shape)
@@ -712,6 +726,8 @@ class FlashAttentionBackend(AttentionBackend):

  # For fa3 interface version compatibility, we put new fields into conditional keyword args
  kwargs = {}
+ if self.fa_impl_ver != 3:
+ kwargs["ver"] = self.fa_impl_ver
  if sinks is not None:
  kwargs["sinks"] = sinks

@@ -770,6 +786,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=use_cascade_attn,
+ num_splits=self.num_splits,
  **kwargs,
  )

@@ -791,6 +808,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=True,
+ num_splits=self.num_splits,
  **kwargs,
  )
  o, _ = merge_state_v2_wrapper(
@@ -830,6 +848,7 @@ class FlashAttentionBackend(AttentionBackend):
  softmax_scale=layer.scaling,
  causal=False,
  return_softmax_lse=True,
+ **kwargs,
  )
  else:
  # MHA for extend part of sequence without attending prefix kv cache
@@ -844,6 +863,7 @@ class FlashAttentionBackend(AttentionBackend):
  softmax_scale=layer.scaling,
  causal=True,
  return_softmax_lse=forward_batch.mha_return_lse,
+ **kwargs,
  )
  if forward_batch.mha_return_lse:
  output, lse, *rest = output
@@ -851,6 +871,7 @@ class FlashAttentionBackend(AttentionBackend):
  return output, lse
  return output
  else:
+ assert self.fa_impl_ver in [3], "Only FA3 support here"
  # Do absorbed multi-latent attention
  kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
  layer.layer_id
@@ -892,6 +913,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=use_cascade_attn,
+ num_splits=self.num_splits,
  )
  if use_cascade_attn:
  o, softmax_lse, *rest = result
@@ -913,6 +935,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=True,
+ num_splits=self.num_splits,
  )
  )
  o, _ = merge_state_v2_wrapper(
@@ -939,6 +962,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_rope: Optional[torch.Tensor] = None,
  sinks: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
+ assert self.fa_impl_ver in [3], "Only FA3 support decoding"
  if k is not None:
  assert v is not None
  if save_kv_cache:
@@ -985,6 +1009,8 @@ class FlashAttentionBackend(AttentionBackend):

  # For fa3 interface version compatibility, we put new fields into conditional keyword args
  kwargs = {}
+ if self.fa_impl_ver != 3:
+ kwargs["ver"] = self.fa_impl_ver
  if sinks is not None:
  kwargs["sinks"] = sinks

@@ -1030,6 +1056,7 @@ class FlashAttentionBackend(AttentionBackend):
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ num_splits=self.num_splits,
  **kwargs,
  )
  elif use_local_attn:
@@ -1049,6 +1076,7 @@ class FlashAttentionBackend(AttentionBackend):
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ num_splits=self.num_splits,
  **kwargs,
  )
  else:
@@ -1077,6 +1105,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=use_cascade_attn,
+ num_splits=self.num_splits,
  **kwargs,
  )
  if use_cascade_attn:
@@ -1098,6 +1127,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=True,
+ num_splits=self.num_splits,
  **kwargs,
  )
  )
@@ -1153,6 +1183,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
+ num_splits=self.num_splits,
  )
  if use_cascade_attn:
  o, softmax_lse, *rest = result
@@ -1173,6 +1204,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=True,
+ num_splits=self.num_splits,
  )
  o, _ = merge_state_v2(
  o,
@@ -1453,7 +1485,7 @@ class FlashAttentionBackend(AttentionBackend):
  seq_lens: torch.Tensor,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ spec_info: Optional[SpecInput],
  ):
  """Initialize forward metadata for capturing CUDA graph."""
  metadata = FlashAttentionMetadata()
@@ -1688,7 +1720,7 @@ class FlashAttentionBackend(AttentionBackend):
  seq_lens_sum: int,
  encoder_lens: Optional[torch.Tensor],
  forward_mode: ForwardMode,
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ spec_info: Optional[SpecInput],
  seq_lens_cpu: Optional[torch.Tensor],
  out_cache_loc: Optional[torch.Tensor] = None,
  ):
@@ -2306,7 +2338,7 @@ class FlashAttentionMultiStepBackend:
  forward_batch: ForwardBatch,
  ):
  assert forward_batch.spec_info is not None
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
+ assert forward_batch.spec_info.is_draft_input()

  for i in range(self.speculative_num_steps - 1):
  self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -2323,7 +2355,7 @@ class FlashAttentionMultiStepBackend:
  self, forward_batch: ForwardBatch, bs: int
  ):
  assert forward_batch.spec_info is not None
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
+ assert forward_batch.spec_info.is_draft_input()

  for i in range(self.speculative_num_steps - 1):
  # TODO: incrementally update the metadata for the later steps,
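The num_splits handling added above pins the split-KV heuristic to a single split when deterministic inference is enabled. The toy snippet below (not sglang code) shows why the split count matters: combining partial sums in a different order generally gives bitwise-different float32 results.

import torch

torch.manual_seed(0)
x = torch.randn(1 << 16, dtype=torch.float32)

# One-pass reduction vs. two "splits" whose partial sums are merged afterwards:
# float addition is not associative, so the two results usually differ in the last bits,
# which is the run-to-run nondeterminism that forcing num_splits=1 avoids.
one_pass = x.sum()
two_splits = x[: 1 << 15].sum() + x[1 << 15 :].sum()
print(one_pass.item(), two_splits.item(), bool(one_pass == two_splits))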