sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -32,12 +32,182 @@ if _is_cuda:
32
32
  _is_hip = is_hip()
33
33
 
34
34
 
35
def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):
    """
    Choose the Triton tile/launch configuration for the extend-attention kernels.

    Args:
        Lq: Query head dimension
        Lv: Value head dimension

    Returns:
        tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps)
    """
    # A few special head sizes are split into a power-of-two main slice
    # (BLOCK_DMODEL) plus a separate BLOCK_DPE slice -- presumably the rotary
    # part of MLA-style heads such as 512 + 64; TODO confirm against the
    # kernels. Any other head size is padded up to a power of two with no
    # extra slice.
    special_splits = {
        576: (512, 64),
        288: (256, 32),
        192: (128, 64),
    }
    if Lq in special_splits:
        block_dmodel, block_dpe = special_splits[Lq]
    else:
        block_dmodel, block_dpe = triton.next_power_of_2(Lq), 0
    block_dv = triton.next_power_of_2(Lv)

    # ROCm uses one fixed configuration regardless of head size.
    if _is_hip:
        return block_dmodel, block_dpe, block_dv, 64, 64, 4

    # CUDA_CAPABILITY is only consulted when _is_cuda is true.
    if _is_cuda and CUDA_CAPABILITY[0] >= 9:
        # sm90 and newer (Hopper, etc.)
        block_m, block_n = (128, 64) if Lq <= 256 else (32, 64)
    elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
        # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
        reduced_smem = CUDA_CAPABILITY[1] in (6, 9)
        if Lq <= 128:
            block_m, block_n = (64, 128) if reduced_smem else (128, 128)
        elif Lq <= 256:
            block_m, block_n = 64, 64
        else:
            block_m, block_n = (32, 32) if reduced_smem else (32, 64)
    else:
        # Pre-Ampere CUDA, or any other non-ROCm device.
        block_m, block_n = (64, 64) if Lq <= 128 else (32, 32)

    num_warps = 4 if Lq <= 64 else 8
    return block_dmodel, block_dpe, block_dv, block_m, block_n, num_warps
97
+
98
+
35
99
@triton.jit
def tanh(x):
    # tanh(x) == 2 * sigmoid(2x) - 1, so Triton's sigmoid can be reused.
    doubled = 2 * x
    return 2 * tl.sigmoid(doubled) - 1
39
103
 
40
104
 
105
@triton.jit
def _copy_unified_indices_kernel(
    # Input buffers
    prefix_kv_indptr,
    prefix_kv_indices,
    extend_start_loc,
    extend_seq_lens,
    extend_kv_indices,
    unified_kv_indptr,
    # Output buffer
    unified_kv_indices,
    # Size
    bs,
):
    """
    Write one sequence's prefix and extend KV indices into the unified buffer.

    Each program handles one sequence ``seq``; programs with ``seq >= bs``
    exit immediately. The prefix indices
    ``prefix_kv_indices[prefix_kv_indptr[seq]:prefix_kv_indptr[seq + 1]]``
    are stored starting at ``unified_kv_indptr[seq]``. They are immediately
    followed by the ``extend_seq_lens[seq]`` entries of ``extend_kv_indices``
    that begin at ``extend_start_loc[seq]``. Both copies run in masked chunks
    of 128 elements.
    """
    seq = tl.program_id(0)

    if seq >= bs:
        return

    src_base = tl.load(prefix_kv_indptr + seq)
    n_prefix = tl.load(prefix_kv_indptr + seq + 1) - src_base
    ext_base = tl.load(extend_start_loc + seq)
    n_extend = tl.load(extend_seq_lens + seq)
    dst_base = tl.load(unified_kv_indptr + seq)

    BLOCK: tl.constexpr = 128
    lanes = tl.arange(0, BLOCK)

    # Prefix run goes first in this sequence's slot.
    for off in range(0, n_prefix, BLOCK):
        pos = off + lanes
        in_range = pos < n_prefix
        src = src_base + pos
        dst = dst_base + pos
        ids = tl.load(prefix_kv_indices + src, mask=in_range, other=0)
        tl.store(unified_kv_indices + dst, ids, mask=in_range)

    # Extend run is appended directly after the prefix run.
    for off in range(0, n_extend, BLOCK):
        pos = off + lanes
        in_range = pos < n_extend
        src = ext_base + pos
        dst = dst_base + n_prefix + pos
        ids = tl.load(extend_kv_indices + src, mask=in_range, other=0)
        tl.store(unified_kv_indices + dst, ids, mask=in_range)
161
+
162
+
163
def build_unified_kv_indices(
    prefix_kv_indptr: torch.Tensor,
    prefix_kv_indices: torch.Tensor,
    extend_start_loc: torch.Tensor,
    extend_seq_lens: torch.Tensor,
    extend_kv_indices: torch.Tensor,
    bs: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Concatenate each sequence's prefix and extend KV indices into one buffer.

    Per-sequence offsets come from ``torch.cumsum``. The index copy itself is
    done by ``_copy_unified_indices_kernel``, launched with one Triton program
    per sequence.

    Args:
        prefix_kv_indptr: Offsets into ``prefix_kv_indices``. Sequence ``i``'s
            prefix indices are
            ``prefix_kv_indices[prefix_kv_indptr[i]:prefix_kv_indptr[i + 1]]``.
            Entries ``0..bs`` are read.
        prefix_kv_indices: Flat prefix KV indices.
        extend_start_loc: Start of sequence ``i``'s indices within
            ``extend_kv_indices``.
        extend_seq_lens: Number of extend indices per sequence; only the first
            ``bs`` entries are used.
        extend_kv_indices: Flat extend KV indices.
        bs: Number of sequences.

    Returns:
        (unified_kv_indptr, unified_kv_indices, prefix_lens)
        - ``unified_kv_indptr`` has ``bs + 1`` entries and starts at 0.
        - ``unified_kv_indices`` is an int64 buffer sized
          ``len(prefix_kv_indices) + len(extend_kv_indices)``. Entries past
          ``unified_kv_indptr[bs]`` are not written.
        - ``prefix_lens`` holds the per-sequence prefix lengths.
    """
    dev = prefix_kv_indptr.device

    prefix_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]
    total_lens = prefix_lens + extend_seq_lens[:bs]

    # Build the offsets with torch.cat instead of assigning into a
    # preallocated tensor (for CUDA graph compatibility).
    leading_zero = torch.zeros(1, dtype=torch.int32, device=dev)
    running_totals = torch.cumsum(total_lens, dim=0)
    unified_kv_indptr = torch.cat([leading_zero, running_totals])

    # Sized to hold every entry of both inputs; torch.empty leaves it
    # uninitialized until the kernel writes the used region.
    capacity = len(prefix_kv_indices) + len(extend_kv_indices)
    unified_kv_indices = torch.empty(capacity, dtype=torch.int64, device=dev)

    grid = (bs,)
    _copy_unified_indices_kernel[grid](
        prefix_kv_indptr,
        prefix_kv_indices,
        extend_start_loc,
        extend_seq_lens,
        extend_kv_indices,
        unified_kv_indptr,
        unified_kv_indices,
        bs,
    )

    return unified_kv_indptr, unified_kv_indices, prefix_lens
209
+
210
+
41
211
  @triton.jit
42
212
  def _fwd_kernel(
43
213
  Q_Extend,
@@ -402,50 +572,10 @@ def extend_attention_fwd(
402
572
  v_extend.shape[-1],
403
573
  )
404
574
 
405
- if Lq == 576:
406
- BLOCK_DMODEL = 512
407
- BLOCK_DPE = 64
408
- elif Lq == 288:
409
- BLOCK_DMODEL = 256
410
- BLOCK_DPE = 32
411
- elif Lq == 192:
412
- BLOCK_DMODEL = 128
413
- BLOCK_DPE = 64
414
- else:
415
- BLOCK_DMODEL = triton.next_power_of_2(Lq)
416
- BLOCK_DPE = 0
417
- BLOCK_DV = triton.next_power_of_2(Lv)
418
-
419
- if _is_hip:
420
- BLOCK_M, BLOCK_N = (64, 64)
421
- num_warps = 4
422
-
423
- else:
424
- if _is_cuda and CUDA_CAPABILITY[0] >= 9:
425
- if Lq <= 256:
426
- BLOCK_M, BLOCK_N = (128, 64)
427
- else:
428
- BLOCK_M, BLOCK_N = (32, 64)
429
- elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
430
- # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
431
- if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
432
- if Lq <= 128:
433
- BLOCK_M, BLOCK_N = (64, 128)
434
- elif Lq <= 256:
435
- BLOCK_M, BLOCK_N = (64, 64)
436
- else:
437
- BLOCK_M, BLOCK_N = (32, 32)
438
- else:
439
- if Lq <= 128:
440
- BLOCK_M, BLOCK_N = (128, 128)
441
- elif Lq <= 256:
442
- BLOCK_M, BLOCK_N = (64, 64)
443
- else:
444
- BLOCK_M, BLOCK_N = (32, 64)
445
- else:
446
- BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
447
-
448
- num_warps = 4 if Lk <= 64 else 8
575
+ # Get block sizes and configuration
576
+ BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
577
+ _get_block_sizes_for_extend_attention(Lq, Lv)
578
+ )
449
579
 
450
580
  sm_scale = sm_scale or 1.0 / (Lq**0.5)
451
581
  batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
@@ -548,3 +678,368 @@ def redundant_attention(
548
678
  pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
549
679
  o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
550
680
  pt += cur_seq_len_extend
681
+
682
+
683
@triton.jit
def _fwd_kernel_unified(
    Q,
    O,
    K_Buffer,
    V_Buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    prefix_lens,
    mask_ptr,
    mask_indptr,
    sink_ptr,
    window_start_pos,
    sm_scale,
    kv_group_num,
    stride_qbs,
    stride_qh,
    stride_obs,
    stride_oh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    SLIDING_WINDOW_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    xai_temperature_len: tl.constexpr,
    Lq: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    USE_CUSTOM_MASK: tl.constexpr,
    HAS_SINK: tl.constexpr,
):
    """
    Unified 1-stage kernel for deterministic extend attention.
    Both prefix and extend KV are accessed through the unified kv_indices.

    Each program handles one (sequence, query head, query tile) triple:
    program_id(0) selects the sequence, program_id(1) the query head and
    program_id(2) a BLOCK_M-row tile of the sequence's extend (new) queries.
    The kernel walks all of the sequence's KV entries in BLOCK_N-sized
    chunks and accumulates the output with an online (running-max) softmax.

    Layout, as read by the code below:
      - Queries of sequence s are rows qo_indptr[s] .. qo_indptr[s + 1] of Q/O.
      - Its keys/values live at the cache slots
        kv_indices[kv_indptr[s] .. kv_indptr[s + 1]]; the first prefix_lens[s]
        entries are treated as prefix, the rest as extend tokens.
      - Query head h reads KV head h // kv_group_num.
      - With USE_CUSTOM_MASK, mask_ptr holds, starting at mask_indptr[s], a
        row-major [q_len, kv_len] mask for sequence s.
      - Q/K features are read at offsets [0, BLOCK_DMODEL) masked to < Lq,
        plus, when BLOCK_DPE > 0, an extra slice
        [BLOCK_DMODEL, BLOCK_DMODEL + BLOCK_DPE) whose scores are added.
      - V/O features are read/written at [0, BLOCK_DV) masked to < Lv.
    """
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_block_m = tl.program_id(2)
    # Grouped-query attention: kv_group_num query heads share one KV head.
    cur_kv_head = cur_head // kv_group_num

    # Load sequence information
    cur_seq_q_start_idx = tl.load(qo_indptr + cur_seq)
    cur_seq_q_len = tl.load(qo_indptr + cur_seq + 1) - cur_seq_q_start_idx
    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
    cur_seq_kv_len = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
    cur_seq_prefix_len = tl.load(prefix_lens + cur_seq)

    # Load window start position for sliding window attention
    # This is the absolute position of the first key in the window (0 if no sliding window)
    cur_window_start = 0
    if SLIDING_WINDOW_SIZE > 0:
        cur_window_start = tl.load(window_start_pos + cur_seq)

    # Load custom mask start index if using custom mask (for speculative decoding)
    if USE_CUSTOM_MASK:
        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    offs_m = tl.arange(0, BLOCK_M)
    # Tile rows at or beyond the sequence's extend length are padding.
    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_q_len
    # Feature offsets >= Lq / >= Lv are masked out (block widths may exceed
    # the real head sizes).
    mask_d = offs_d < Lq
    mask_dv = offs_dv < Lv

    # XAI temperature handling
    # Per-query-row multiplier applied to the scaled scores inside the loop:
    # 1.0 for absolute positions (prefix_len + row) < xai_temperature_len,
    # otherwise xai_temperature_len / (position + 1).
    # NOTE(review): confirm this formula matches the temperature scaling used
    # by the non-unified `_fwd_kernel`. If they differ, the deterministic path
    # produces different logits for models that set xai_temperature_len.
    if xai_temperature_len > 0:
        offs_qidx = cur_seq_prefix_len + cur_block_m * BLOCK_M + offs_m
        xai_temperature_reg = tl.where(
            offs_qidx < xai_temperature_len,
            1.0,
            xai_temperature_len / (offs_qidx + 1.0),
        )

    # Load Q
    offs_q = (
        (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    q = tl.load(Q + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0)

    if BLOCK_DPE > 0:
        # Extra trailing feature slice; only the row mask is applied to it.
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        offs_qpe = (
            (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs
            + cur_head * stride_qh
            + offs_dpe[None, :]
        )
        qpe = tl.load(Q + offs_qpe, mask=mask_m[:, None], other=0.0)

    # Initialize accumulators
    # acc: un-normalized output rows; deno: running softmax denominator;
    # e_max: running per-row maximum score (starts at -inf).
    offs_n = tl.arange(0, BLOCK_N)
    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

    # Unified loop: process all KV tokens (prefix + extend)
    # Chunks are visited in ascending order. The iteration space depends only
    # on this sequence's KV length.
    for start_n in range(0, cur_seq_kv_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_seq_kv_len

        # Compute mask
        final_mask = mask_m[:, None] & mask_n[None, :]

        # Apply custom mask if provided
        # Row stride of the per-sequence mask is the sequence's full KV length.
        if USE_CUSTOM_MASK:
            custom_mask = tl.load(
                mask_ptr
                + cur_seq_mask_start_idx
                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_kv_len
                + start_n
                + offs_n[None, :],
                mask=(mask_m[:, None] & mask_n[None, :]),
                other=0,
            )
            final_mask &= custom_mask

        # Apply causal mask for extend part
        # (skipped when a custom mask is supplied; the custom mask then decides)
        if IS_CAUSAL and not USE_CUSTOM_MASK:
            # Determine if current KV block is in extend region
            # Only apply causal mask when both Q and K are in extend region
            q_idx = cur_block_m * BLOCK_M + offs_m[:, None]
            k_idx_in_total = start_n + offs_n[None, :]

            # Causal mask: q_idx >= (k_idx - prefix_len) when k_idx >= prefix_len
            # For prefix region (k_idx < prefix_len), no causal mask
            k_is_extend = k_idx_in_total >= cur_seq_prefix_len
            k_idx_in_extend = k_idx_in_total - cur_seq_prefix_len
            causal_mask = tl.where(
                k_is_extend,
                q_idx >= k_idx_in_extend,
                True,  # No causal mask for prefix
            )
            final_mask &= causal_mask

        if SLIDING_WINDOW_SIZE > 0:
            # Sliding window mask with correct absolute positions
            # Q absolute position: window_start + prefix_len + q_position_in_extend
            q_abs_pos = (
                cur_window_start
                + cur_seq_prefix_len
                + cur_block_m * BLOCK_M
                + offs_m[:, None]
            )

            # K absolute position: window_start + k_index_in_unified_array
            k_abs_pos = cur_window_start + start_n + offs_n[None, :]

            # Sliding window: query can attend to keys within window_size
            # i.e. keep keys with q_abs_pos - k_abs_pos <= SLIDING_WINDOW_SIZE.
            # Only this lower bound is imposed here; keys after the query are
            # excluded by the causal/custom mask, not by this check.
            window_mask = q_abs_pos <= (k_abs_pos + SLIDING_WINDOW_SIZE)
            final_mask &= window_mask

        # Check if we can skip this tile
        # Only computed when a custom mask or sliding window could empty a whole
        # tile. A skipped tile leaves acc / deno / e_max unchanged.
        SKIP_TILE = False
        if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

        if not SKIP_TILE:
            # Load KV indices
            # (cache slot of each key in this chunk)
            offs_kv_loc = tl.load(
                kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
                mask=mask_n,
                other=0,
            )

            # Load K
            # K is loaded as [features, BLOCK_N] so it can feed tl.dot directly.
            offs_buf_k = (
                offs_kv_loc[None, :] * stride_buf_kbs
                + cur_kv_head * stride_buf_kh
                + offs_d[:, None]
            )
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(mask_n[None, :]) & (mask_d[:, None]),
                other=0.0,
            )

            # Compute QK
            qk = tl.dot(q.to(k.dtype), k)
            if BLOCK_DPE > 0:
                offs_kpe = (
                    offs_kv_loc[None, :] * stride_buf_kbs
                    + cur_kv_head * stride_buf_kh
                    + offs_dpe[:, None]
                )
                kpe = tl.load(
                    K_Buffer + offs_kpe,
                    mask=mask_n[None, :],
                    other=0.0,
                )
                qk += tl.dot(qpe.to(kpe.dtype), kpe)

            qk *= sm_scale

            if logit_cap > 0:
                # Soft cap: tanh squashes scores smoothly into (-logit_cap, logit_cap).
                qk = logit_cap * tanh(qk / logit_cap)

            if xai_temperature_len > 0:
                qk *= xai_temperature_reg[:, None]

            qk = tl.where(final_mask, qk, float("-inf"))

            # Online softmax
            row_max = tl.max(qk, 1)
            # A fully masked row has max -inf. Clamp it to a large finite
            # negative so that e_max - n_e_max below never evaluates
            # (-inf) - (-inf) = NaN.
            row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
            n_e_max = tl.maximum(row_max_fixed, e_max)

            # Rescale previously accumulated terms to the new running maximum.
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            deno = deno * re_scale + tl.sum(p, 1)

            # Load V
            offs_buf_v = (
                offs_kv_loc[:, None] * stride_buf_vbs
                + cur_kv_head * stride_buf_vh
                + offs_dv[None, :]
            )
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=mask_n[:, None] & mask_dv[None, :],
                other=0.0,
            )
            p = p.to(v.dtype)
            acc = acc * re_scale[:, None] + tl.dot(p, v)

            e_max = n_e_max

    # Handle sink tokens
    # The per-head sink logit only enlarges the softmax denominator. It adds
    # nothing to acc, so its effect is to shrink the normalized output.
    if HAS_SINK:
        cur_sink = tl.load(sink_ptr + cur_head)
        deno += tl.exp(cur_sink - e_max)

    # Store output
    # NOTE(review): a valid query row whose keys were all masked (possible with
    # a custom mask) still has deno == 0 here unless a sink was added, so
    # acc / deno stores 0 / 0 for that row.
    offs_o = (
        (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_dv[None, :]
    )
    tl.store(
        O + offs_o,
        acc / deno[:, None],
        mask=mask_m[:, None] & mask_dv[None, :],
    )
933
+
934
+
935
def extend_attention_fwd_unified(
    q,
    o,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    prefix_lens,
    max_len_extend,
    custom_mask=None,
    mask_indptr=None,
    sm_scale=None,
    logit_cap=0.0,
    is_causal=True,
    sliding_window_size=-1,
    sinks=None,
    window_start_pos=None,
    xai_temperature_len=-1,
):
    """Launch the single-stage (unified prefix + extend) extend-attention kernel.

    Prefix and extend keys/values are both addressed through ``kv_indices``, so
    each query tile is computed in one pass over its sequence's KV entries.
    The result is written into ``o`` in place; nothing is returned.

    Args:
        q: Query tensor [num_tokens, num_heads, head_dim].
        o: Output tensor [num_tokens, num_heads, head_dim], filled by the kernel.
        k_buffer: Key cache buffer; ``k_buffer.shape[1]`` is taken as the number
            of KV heads.
        v_buffer: Value cache buffer; its last dimension is the value head size.
        qo_indptr: Query offsets [batch_size + 1].
        kv_indptr: KV offsets [batch_size + 1], covering prefix and extend.
        kv_indices: Unified KV indices (each sequence's prefix entries first,
            then its extend entries).
        prefix_lens: Prefix length of each sequence [batch_size].
        max_len_extend: Longest extend length in the batch; determines how many
            BLOCK_M query tiles are launched per (sequence, head).
        custom_mask: Optional custom attention mask (e.g. speculative-decoding
            tree attention).
        mask_indptr: Mask offsets [batch_size + 1], used with ``custom_mask``.
        sm_scale: Softmax scale; a falsy value falls back to
            ``1 / sqrt(q.shape[-1])``.
        logit_cap: Logit soft-capping value (disabled when <= 0).
        is_causal: Whether to apply the causal mask.
        sliding_window_size: Sliding window size (-1 disables it).
        sinks: Optional per-head sink logits.
        window_start_pos: Absolute position of the first key in each sequence's
            sliding window [batch_size]; zeros are used when a window is active
            and this is None.
        xai_temperature_len: XAI temperature length (disabled when <= 0).
    """
    q_head_dim = q.shape[-1]
    v_head_dim = v_buffer.shape[-1]

    # Tile shapes and warp count shared with the non-unified extend path.
    (
        block_dmodel,
        block_dpe,
        block_dv,
        block_m,
        block_n,
        num_warps,
    ) = _get_block_sizes_for_extend_attention(q_head_dim, v_head_dim)

    if not sm_scale:
        sm_scale = 1.0 / (q_head_dim**0.5)

    num_seqs = qo_indptr.shape[0] - 1
    num_q_heads = q.shape[1]
    # Number of query heads that share a single KV head (grouped-query attention).
    heads_per_kv = q.shape[1] // k_buffer.shape[1]

    if sliding_window_size > 0 and window_start_pos is None:
        # No explicit window origin given: every window starts at position 0.
        window_start_pos = torch.zeros(num_seqs, dtype=torch.int32, device=q.device)

    launch_grid = (num_seqs, num_q_heads, triton.cdiv(max_len_extend, block_m))

    # Backend-specific compile options (AMD/HIP only).
    backend_options = (
        {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} if _is_hip else {}
    )

    _fwd_kernel_unified[launch_grid](
        q,
        o,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        prefix_lens,
        custom_mask,
        mask_indptr,
        sinks,
        window_start_pos,
        sm_scale,
        heads_per_kv,
        q.stride(0),
        q.stride(1),
        o.stride(0),
        o.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        v_buffer.stride(0),
        v_buffer.stride(1),
        SLIDING_WINDOW_SIZE=sliding_window_size,
        logit_cap=logit_cap,
        xai_temperature_len=xai_temperature_len,
        Lq=q_head_dim,
        Lv=v_head_dim,
        BLOCK_DMODEL=block_dmodel,
        BLOCK_DPE=block_dpe,
        BLOCK_DV=block_dv,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        IS_CAUSAL=is_causal,
        USE_CUSTOM_MASK=custom_mask is not None,
        HAS_SINK=sinks is not None,
        num_warps=num_warps,
        num_stages=1,
        **backend_options,
    )
@@ -637,7 +637,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
637
637
  self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
638
638
  ):
639
639
  super().__init__(model_runner, topk, speculative_num_steps)
640
- for i in range(speculative_num_steps):
640
+ for i in range(self.speculative_num_steps - 1):
641
641
  self.attn_backends[i] = TRTLLMHAAttnBackend(
642
642
  model_runner,
643
643
  skip_prefill=True,
@@ -651,7 +651,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
651
651
  self.attn_backends[i].init_forward_metadata(forward_batch)
652
652
 
653
653
  def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
654
- for i in range(self.speculative_num_steps):
654
+ for i in range(self.speculative_num_steps - 1):
655
655
  self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
656
656
 
657
657
  def init_forward_metadata_capture_cuda_graph(