sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,6 @@
5
5
  # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
6
6
  # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
7
7
 
8
- import math
9
8
 
10
9
  import torch
11
10
  import torch.nn.functional as F
@@ -13,6 +12,8 @@ import triton
13
12
  import triton.language as tl
14
13
  from einops import rearrange
15
14
 
15
+ from sglang.srt.utils import device_context
16
+
16
17
 
17
18
  def rms_norm_ref(
18
19
  x,
@@ -158,7 +159,7 @@ def _layer_norm_fwd(
158
159
  # heuristics for number of warps
159
160
  num_warps = min(max(BLOCK_N // 256, 1), 8)
160
161
  grid = (M, ngroups)
161
- with torch.get_device_module(x.device).device(x.device.index):
162
+ with device_context(x.device):
162
163
  _layer_norm_fwd_1pass_kernel[grid](
163
164
  x,
164
165
  out,
@@ -181,6 +182,45 @@ def _layer_norm_fwd(
181
182
  return out, mean, rstd
182
183
 
183
184
 
185
+ def rms_norm_gated(
186
+ *,
187
+ x,
188
+ weight,
189
+ bias,
190
+ z=None,
191
+ eps=1e-6,
192
+ group_size=None,
193
+ norm_before_gate=True,
194
+ is_rms_norm=False,
195
+ ):
196
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
197
+
198
+ x_shape_og = x.shape
199
+ # reshape input data into 2D tensor
200
+ x = x.reshape(-1, x.shape[-1])
201
+ if x.stride(-1) != 1:
202
+ x = x.contiguous()
203
+ if z is not None:
204
+ assert z.shape == x_shape_og
205
+ z = z.reshape(-1, z.shape[-1])
206
+ if z.stride(-1) != 1:
207
+ z = z.contiguous()
208
+ weight = weight.contiguous()
209
+ if bias is not None:
210
+ bias = bias.contiguous()
211
+ y, mean, rstd = _layer_norm_fwd(
212
+ x,
213
+ weight,
214
+ bias,
215
+ eps,
216
+ z=z,
217
+ group_size=group_size,
218
+ norm_before_gate=norm_before_gate,
219
+ is_rms_norm=is_rms_norm,
220
+ )
221
+ return y.reshape(x_shape_og)
222
+
223
+
184
224
  class LayerNormFn(torch.autograd.Function):
185
225
 
186
226
  @staticmethod
@@ -195,32 +235,16 @@ class LayerNormFn(torch.autograd.Function):
195
235
  norm_before_gate=True,
196
236
  is_rms_norm=False,
197
237
  ):
198
- """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
199
-
200
- x_shape_og = x.shape
201
- # reshape input data into 2D tensor
202
- x = x.reshape(-1, x.shape[-1])
203
- if x.stride(-1) != 1:
204
- x = x.contiguous()
205
- if z is not None:
206
- assert z.shape == x_shape_og
207
- z = z.reshape(-1, z.shape[-1])
208
- if z.stride(-1) != 1:
209
- z = z.contiguous()
210
- weight = weight.contiguous()
211
- if bias is not None:
212
- bias = bias.contiguous()
213
- y, mean, rstd = _layer_norm_fwd(
214
- x,
215
- weight,
216
- bias,
217
- eps,
238
+ return rms_norm_gated(
239
+ x=x,
240
+ weight=weight,
241
+ bias=bias,
242
+ eps=eps,
218
243
  z=z,
219
244
  group_size=group_size,
220
245
  norm_before_gate=norm_before_gate,
221
246
  is_rms_norm=is_rms_norm,
222
247
  )
223
- return y.reshape(x_shape_og)
224
248
 
225
249
 
226
250
  def layernorm_fn(
@@ -238,14 +262,6 @@ def layernorm_fn(
238
262
  )
239
263
 
240
264
 
241
- def rmsnorm_fn(
242
- x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
243
- ):
244
- return LayerNormFn.apply(
245
- x, weight, bias, z, eps, group_size, norm_before_gate, True
246
- )
247
-
248
-
249
265
  class LayerNorm(torch.nn.Module):
250
266
 
251
267
  def __init__(
@@ -284,6 +300,7 @@ class LayerNorm(torch.nn.Module):
284
300
  group_size=self.group_size,
285
301
  eps=self.eps,
286
302
  norm_before_gate=self.norm_before_gate,
303
+ is_rms_norm=False,
287
304
  )
288
305
 
289
306
 
@@ -315,7 +332,7 @@ class RMSNorm(torch.nn.Module):
315
332
 
316
333
  def forward(self, x, z=None):
317
334
  """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
318
- return rmsnorm_fn(
335
+ return layernorm_fn(
319
336
  x,
320
337
  self.weight,
321
338
  self.bias,
@@ -323,4 +340,5 @@ class RMSNorm(torch.nn.Module):
323
340
  eps=self.eps,
324
341
  group_size=self.group_size,
325
342
  norm_before_gate=self.norm_before_gate,
343
+ is_rms_norm=True,
326
344
  )
@@ -58,9 +58,6 @@ def check_environments():
58
58
  return None
59
59
 
60
60
 
61
- check_environments()
62
-
63
-
64
61
  def get_abs_err(x, y):
65
62
  return (x.detach() - y.detach()).flatten().abs().max().item()
66
63
 
@@ -9,8 +9,6 @@ import triton
9
9
  import triton.language as tl
10
10
 
11
11
  from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
12
- from sglang.srt.layers.attention.fla.op import safe_exp
13
- from sglang.srt.layers.attention.fla.utils import check_shared_mem
14
12
 
15
13
 
16
14
  @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import TYPE_CHECKING, Optional, Union
4
+ from typing import TYPE_CHECKING, Optional
5
5
 
6
6
  import numpy as np
7
7
  import torch
@@ -10,10 +10,10 @@ import triton.language as tl
10
10
 
11
11
  from sglang.srt.configs.model_config import AttentionArch
12
12
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
13
- from sglang.srt.managers.schedule_batch import global_server_args_dict
14
- from sglang.srt.mem_cache.memory_pool import SWAKVPool
13
+ from sglang.srt.layers.radix_attention import AttentionType
15
14
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
16
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
15
+ from sglang.srt.server_args import get_global_server_args
16
+ from sglang.srt.speculative.spec_info import SpecInput
17
17
 
18
18
  if TYPE_CHECKING:
19
19
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -305,6 +305,7 @@ class FlashAttentionBackend(AttentionBackend):
305
305
  speculative_step_id=0,
306
306
  topk=0,
307
307
  speculative_num_steps=0,
308
+ fa_impl_ver=3,
308
309
  ):
309
310
  super().__init__()
310
311
 
@@ -338,6 +339,8 @@ class FlashAttentionBackend(AttentionBackend):
338
339
  )
339
340
  self.speculative_step_id = speculative_step_id
340
341
 
342
+ self.fa_impl_ver = fa_impl_ver
343
+
341
344
  # Local attention settings
342
345
  self.attention_chunk_size = (
343
346
  model_runner.attention_chunk_size
@@ -352,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend):
352
355
  self.sliding_window_size is not None and self.sliding_window_size > -1
353
356
  )
354
357
 
358
+ # If num_splits == 0, we use a heuristic to automatically determine the number of splits.
359
+ # We set nums splits to 1 if deterministic inference is enabled.
360
+ # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
361
+ self.num_splits = (
362
+ 1 if model_runner.server_args.enable_deterministic_inference else 0
363
+ )
364
+
355
365
  def init_forward_metadata(self, forward_batch: ForwardBatch):
356
366
  """Initialize forward metadata hence all layers in the forward pass can reuse it."""
357
367
  metadata = FlashAttentionMetadata()
@@ -682,8 +692,13 @@ class FlashAttentionBackend(AttentionBackend):
682
692
  k_descale, v_descale = None, None
683
693
  # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
684
694
  # has corresponding quantization method so that layer.k_scale is not None,
685
- # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
686
- if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
695
+ # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
696
+ # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
697
+ if (
698
+ self.kv_cache_dtype_str != "auto"
699
+ and layer.head_dim <= 256
700
+ and self.fa_impl_ver != 4
701
+ ):
687
702
  if layer.k_scale is not None:
688
703
  descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
689
704
  k_descale = layer.k_scale.expand(descale_shape)
@@ -691,7 +706,9 @@ class FlashAttentionBackend(AttentionBackend):
691
706
  q = q.to(self.kv_cache_dtype)
692
707
  q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
693
708
  k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
694
- causal = not layer.is_cross_attention
709
+ causal = True
710
+ if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
711
+ causal = False
695
712
 
696
713
  # Check if we should use local attention
697
714
  use_local_attn = (
@@ -712,6 +729,8 @@ class FlashAttentionBackend(AttentionBackend):
712
729
 
713
730
  # For fa3 interface version compatibility, we put new fields into conditional keyword args
714
731
  kwargs = {}
732
+ if self.fa_impl_ver != 3:
733
+ kwargs["ver"] = self.fa_impl_ver
715
734
  if sinks is not None:
716
735
  kwargs["sinks"] = sinks
717
736
 
@@ -770,6 +789,7 @@ class FlashAttentionBackend(AttentionBackend):
770
789
  k_descale=k_descale,
771
790
  v_descale=v_descale,
772
791
  return_softmax_lse=use_cascade_attn,
792
+ num_splits=self.num_splits,
773
793
  **kwargs,
774
794
  )
775
795
 
@@ -791,6 +811,7 @@ class FlashAttentionBackend(AttentionBackend):
791
811
  k_descale=k_descale,
792
812
  v_descale=v_descale,
793
813
  return_softmax_lse=True,
814
+ num_splits=self.num_splits,
794
815
  **kwargs,
795
816
  )
796
817
  o, _ = merge_state_v2_wrapper(
@@ -809,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
809
830
  ):
810
831
  # Do multi-head attention with chunked prefix cache
811
832
  if forward_batch.attn_attend_prefix_cache:
812
- assert not global_server_args_dict["disable_chunked_prefix_cache"]
833
+ assert not get_global_server_args().disable_chunked_prefix_cache
813
834
  # MHA for chunked prefix kv cache when running model with MLA
814
835
  assert forward_batch.prefix_chunk_idx is not None
815
836
  assert forward_batch.prefix_chunk_cu_seq_lens is not None
@@ -830,6 +851,7 @@ class FlashAttentionBackend(AttentionBackend):
830
851
  softmax_scale=layer.scaling,
831
852
  causal=False,
832
853
  return_softmax_lse=True,
854
+ **kwargs,
833
855
  )
834
856
  else:
835
857
  # MHA for extend part of sequence without attending prefix kv cache
@@ -844,6 +866,7 @@ class FlashAttentionBackend(AttentionBackend):
844
866
  softmax_scale=layer.scaling,
845
867
  causal=True,
846
868
  return_softmax_lse=forward_batch.mha_return_lse,
869
+ **kwargs,
847
870
  )
848
871
  if forward_batch.mha_return_lse:
849
872
  output, lse, *rest = output
@@ -851,6 +874,7 @@ class FlashAttentionBackend(AttentionBackend):
851
874
  return output, lse
852
875
  return output
853
876
  else:
877
+ assert self.fa_impl_ver in [3], "Only FA3 support here"
854
878
  # Do absorbed multi-latent attention
855
879
  kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
856
880
  layer.layer_id
@@ -892,6 +916,7 @@ class FlashAttentionBackend(AttentionBackend):
892
916
  k_descale=k_descale,
893
917
  v_descale=v_descale,
894
918
  return_softmax_lse=use_cascade_attn,
919
+ num_splits=self.num_splits,
895
920
  )
896
921
  if use_cascade_attn:
897
922
  o, softmax_lse, *rest = result
@@ -913,6 +938,7 @@ class FlashAttentionBackend(AttentionBackend):
913
938
  k_descale=k_descale,
914
939
  v_descale=v_descale,
915
940
  return_softmax_lse=True,
941
+ num_splits=self.num_splits,
916
942
  )
917
943
  )
918
944
  o, _ = merge_state_v2_wrapper(
@@ -939,6 +965,7 @@ class FlashAttentionBackend(AttentionBackend):
939
965
  k_rope: Optional[torch.Tensor] = None,
940
966
  sinks: Optional[torch.Tensor] = None,
941
967
  ) -> torch.Tensor:
968
+ assert self.fa_impl_ver in [3], "Only FA3 support decoding"
942
969
  if k is not None:
943
970
  assert v is not None
944
971
  if save_kv_cache:
@@ -981,10 +1008,14 @@ class FlashAttentionBackend(AttentionBackend):
981
1008
  if layer.sliding_window_size is not None and layer.sliding_window_size > -1
982
1009
  else (-1, -1)
983
1010
  )
984
- causal = not layer.is_cross_attention
1011
+ causal = True
1012
+ if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
1013
+ causal = False
985
1014
 
986
1015
  # For fa3 interface version compatibility, we put new fields into conditional keyword args
987
1016
  kwargs = {}
1017
+ if self.fa_impl_ver != 3:
1018
+ kwargs["ver"] = self.fa_impl_ver
988
1019
  if sinks is not None:
989
1020
  kwargs["sinks"] = sinks
990
1021
 
@@ -1030,6 +1061,7 @@ class FlashAttentionBackend(AttentionBackend):
1030
1061
  softcap=layer.logit_cap,
1031
1062
  k_descale=k_descale,
1032
1063
  v_descale=v_descale,
1064
+ num_splits=self.num_splits,
1033
1065
  **kwargs,
1034
1066
  )
1035
1067
  elif use_local_attn:
@@ -1049,6 +1081,7 @@ class FlashAttentionBackend(AttentionBackend):
1049
1081
  softcap=layer.logit_cap,
1050
1082
  k_descale=k_descale,
1051
1083
  v_descale=v_descale,
1084
+ num_splits=self.num_splits,
1052
1085
  **kwargs,
1053
1086
  )
1054
1087
  else:
@@ -1077,6 +1110,7 @@ class FlashAttentionBackend(AttentionBackend):
1077
1110
  k_descale=k_descale,
1078
1111
  v_descale=v_descale,
1079
1112
  return_softmax_lse=use_cascade_attn,
1113
+ num_splits=self.num_splits,
1080
1114
  **kwargs,
1081
1115
  )
1082
1116
  if use_cascade_attn:
@@ -1098,6 +1132,7 @@ class FlashAttentionBackend(AttentionBackend):
1098
1132
  k_descale=k_descale,
1099
1133
  v_descale=v_descale,
1100
1134
  return_softmax_lse=True,
1135
+ num_splits=self.num_splits,
1101
1136
  **kwargs,
1102
1137
  )
1103
1138
  )
@@ -1153,6 +1188,7 @@ class FlashAttentionBackend(AttentionBackend):
1153
1188
  k_descale=k_descale,
1154
1189
  v_descale=v_descale,
1155
1190
  return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
1191
+ num_splits=self.num_splits,
1156
1192
  )
1157
1193
  if use_cascade_attn:
1158
1194
  o, softmax_lse, *rest = result
@@ -1173,6 +1209,7 @@ class FlashAttentionBackend(AttentionBackend):
1173
1209
  k_descale=k_descale,
1174
1210
  v_descale=v_descale,
1175
1211
  return_softmax_lse=True,
1212
+ num_splits=self.num_splits,
1176
1213
  )
1177
1214
  o, _ = merge_state_v2(
1178
1215
  o,
@@ -1453,7 +1490,7 @@ class FlashAttentionBackend(AttentionBackend):
1453
1490
  seq_lens: torch.Tensor,
1454
1491
  encoder_lens: Optional[torch.Tensor],
1455
1492
  forward_mode: ForwardMode,
1456
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
1493
+ spec_info: Optional[SpecInput],
1457
1494
  ):
1458
1495
  """Initialize forward metadata for capturing CUDA graph."""
1459
1496
  metadata = FlashAttentionMetadata()
@@ -1688,7 +1725,7 @@ class FlashAttentionBackend(AttentionBackend):
1688
1725
  seq_lens_sum: int,
1689
1726
  encoder_lens: Optional[torch.Tensor],
1690
1727
  forward_mode: ForwardMode,
1691
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
1728
+ spec_info: Optional[SpecInput],
1692
1729
  seq_lens_cpu: Optional[torch.Tensor],
1693
1730
  out_cache_loc: Optional[torch.Tensor] = None,
1694
1731
  ):
@@ -2283,7 +2320,7 @@ class FlashAttentionMultiStepBackend:
2283
2320
  self.topk = topk
2284
2321
  self.speculative_num_steps = speculative_num_steps
2285
2322
  self.attn_backends = []
2286
- for i in range(self.speculative_num_steps):
2323
+ for i in range(self.speculative_num_steps - 1):
2287
2324
  self.attn_backends.append(
2288
2325
  FlashAttentionBackend(
2289
2326
  model_runner,
@@ -2298,7 +2335,7 @@ class FlashAttentionMultiStepBackend:
2298
2335
  self.attn_backends[i].init_forward_metadata(forward_batch)
2299
2336
 
2300
2337
  def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
2301
- for i in range(self.speculative_num_steps):
2338
+ for i in range(self.speculative_num_steps - 1):
2302
2339
  self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
2303
2340
 
2304
2341
  def init_forward_metadata_capture_cuda_graph(
@@ -2306,7 +2343,7 @@ class FlashAttentionMultiStepBackend:
2306
2343
  forward_batch: ForwardBatch,
2307
2344
  ):
2308
2345
  assert forward_batch.spec_info is not None
2309
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
2346
+ assert forward_batch.spec_info.is_draft_input()
2310
2347
 
2311
2348
  for i in range(self.speculative_num_steps - 1):
2312
2349
  self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -2323,7 +2360,7 @@ class FlashAttentionMultiStepBackend:
2323
2360
  self, forward_batch: ForwardBatch, bs: int
2324
2361
  ):
2325
2362
  assert forward_batch.spec_info is not None
2326
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
2363
+ assert forward_batch.spec_info.is_draft_input()
2327
2364
 
2328
2365
  for i in range(self.speculative_num_steps - 1):
2329
2366
  # TODO: incrementally update the metadata for the later steps,