sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -32,12 +32,182 @@ if _is_cuda:
32
32
  _is_hip = is_hip()
33
33
 
34
34
 
35
+ def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):
36
+ """
37
+ Get block sizes and configuration for extend attention kernels.
38
+
39
+ Args:
40
+ Lq: Query head dimension
41
+ Lv: Value head dimension
42
+
43
+ Returns:
44
+ tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps)
45
+ """
46
+ # Determine BLOCK_DMODEL and BLOCK_DPE based on head dimension
47
+ if Lq == 576:
48
+ BLOCK_DMODEL = 512
49
+ BLOCK_DPE = 64
50
+ elif Lq == 288:
51
+ BLOCK_DMODEL = 256
52
+ BLOCK_DPE = 32
53
+ elif Lq == 192:
54
+ BLOCK_DMODEL = 128
55
+ BLOCK_DPE = 64
56
+ else:
57
+ BLOCK_DMODEL = triton.next_power_of_2(Lq)
58
+ BLOCK_DPE = 0
59
+
60
+ BLOCK_DV = triton.next_power_of_2(Lv)
61
+
62
+ # Determine BLOCK_M, BLOCK_N, and num_warps based on hardware
63
+ if _is_hip:
64
+ BLOCK_M, BLOCK_N = (64, 64)
65
+ num_warps = 4
66
+ else:
67
+ if _is_cuda and CUDA_CAPABILITY[0] >= 9:
68
+ # Hopper architecture (H100, etc.)
69
+ if Lq <= 256:
70
+ BLOCK_M, BLOCK_N = (128, 64)
71
+ else:
72
+ BLOCK_M, BLOCK_N = (32, 64)
73
+ elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
74
+ # Ampere architecture (A100, etc.)
75
+ # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
76
+ if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
77
+ if Lq <= 128:
78
+ BLOCK_M, BLOCK_N = (64, 128)
79
+ elif Lq <= 256:
80
+ BLOCK_M, BLOCK_N = (64, 64)
81
+ else:
82
+ BLOCK_M, BLOCK_N = (32, 32)
83
+ else:
84
+ if Lq <= 128:
85
+ BLOCK_M, BLOCK_N = (128, 128)
86
+ elif Lq <= 256:
87
+ BLOCK_M, BLOCK_N = (64, 64)
88
+ else:
89
+ BLOCK_M, BLOCK_N = (32, 64)
90
+ else:
91
+ # Older architectures
92
+ BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
93
+
94
+ num_warps = 4 if Lq <= 64 else 8
95
+
96
+ return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps
97
+
98
+
35
99
  @triton.jit
36
100
  def tanh(x):
37
101
  # Tanh is just a scaled sigmoid
38
102
  return 2 * tl.sigmoid(2 * x) - 1
39
103
 
40
104
 
105
+ @triton.jit
106
+ def _copy_unified_indices_kernel(
107
+ # Input buffers
108
+ prefix_kv_indptr,
109
+ prefix_kv_indices,
110
+ extend_start_loc,
111
+ extend_seq_lens,
112
+ extend_kv_indices,
113
+ unified_kv_indptr,
114
+ # Output buffer
115
+ unified_kv_indices,
116
+ # Size
117
+ bs,
118
+ ):
119
+ """
120
+ Triton kernel to copy indices to unified buffer (parallel per sequence).
121
+ Each thread block processes one sequence with vectorized loads/stores.
122
+ """
123
+ pid = tl.program_id(0)
124
+
125
+ if pid >= bs:
126
+ return
127
+
128
+ # Load sequence info
129
+ prefix_start = tl.load(prefix_kv_indptr + pid)
130
+ prefix_end = tl.load(prefix_kv_indptr + pid + 1)
131
+ extend_start = tl.load(extend_start_loc + pid)
132
+ extend_len = tl.load(extend_seq_lens + pid)
133
+
134
+ prefix_len = prefix_end - prefix_start
135
+ unified_start = tl.load(unified_kv_indptr + pid)
136
+
137
+ # Copy indices in vectorized chunks
138
+ BLOCK_SIZE: tl.constexpr = 128
139
+
140
+ # Process prefix indices
141
+ for block_start in range(0, prefix_len, BLOCK_SIZE):
142
+ offs = block_start + tl.arange(0, BLOCK_SIZE)
143
+ mask = offs < prefix_len
144
+
145
+ src_idx = prefix_start + offs
146
+ dst_idx = unified_start + offs
147
+
148
+ vals = tl.load(prefix_kv_indices + src_idx, mask=mask, other=0)
149
+ tl.store(unified_kv_indices + dst_idx, vals, mask=mask)
150
+
151
+ # Process extend indices
152
+ for block_start in range(0, extend_len, BLOCK_SIZE):
153
+ offs = block_start + tl.arange(0, BLOCK_SIZE)
154
+ mask = offs < extend_len
155
+
156
+ src_idx = extend_start + offs
157
+ dst_idx = unified_start + prefix_len + offs
158
+
159
+ vals = tl.load(extend_kv_indices + src_idx, mask=mask, other=0)
160
+ tl.store(unified_kv_indices + dst_idx, vals, mask=mask)
161
+
162
+
163
+ def build_unified_kv_indices(
164
+ prefix_kv_indptr: torch.Tensor,
165
+ prefix_kv_indices: torch.Tensor,
166
+ extend_start_loc: torch.Tensor,
167
+ extend_seq_lens: torch.Tensor,
168
+ extend_kv_indices: torch.Tensor,
169
+ bs: int,
170
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
171
+ """
172
+ Build unified KV indices efficiently:
173
+ - Use PyTorch's optimized cumsum (NVIDIA CUB) for indptr
174
+ - Use Triton kernel for parallel index copying
175
+
176
+ Returns:
177
+ (unified_kv_indptr, unified_kv_indices, prefix_lens)
178
+ """
179
+ device = prefix_kv_indptr.device
180
+
181
+ prefix_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]
182
+
183
+ # Create unified_kv_indptr avoiding direct assignment (for CUDA graph compatibility)
184
+ unified_lens = prefix_lens + extend_seq_lens[:bs]
185
+ unified_kv_indptr = torch.cat(
186
+ [
187
+ torch.zeros(1, dtype=torch.int32, device=device),
188
+ torch.cumsum(unified_lens, dim=0),
189
+ ]
190
+ )
191
+
192
+ max_unified_len = len(prefix_kv_indices) + len(extend_kv_indices)
193
+
194
+ unified_kv_indices = torch.empty(max_unified_len, dtype=torch.int64, device=device)
195
+
196
+ # Launch Triton kernel for parallel index copying
197
+ _copy_unified_indices_kernel[(bs,)](
198
+ prefix_kv_indptr,
199
+ prefix_kv_indices,
200
+ extend_start_loc,
201
+ extend_seq_lens,
202
+ extend_kv_indices,
203
+ unified_kv_indptr,
204
+ unified_kv_indices,
205
+ bs,
206
+ )
207
+
208
+ return unified_kv_indptr, unified_kv_indices, prefix_lens
209
+
210
+
41
211
  @triton.jit
42
212
  def _fwd_kernel(
43
213
  Q_Extend,
@@ -402,50 +572,10 @@ def extend_attention_fwd(
402
572
  v_extend.shape[-1],
403
573
  )
404
574
 
405
- if Lq == 576:
406
- BLOCK_DMODEL = 512
407
- BLOCK_DPE = 64
408
- elif Lq == 288:
409
- BLOCK_DMODEL = 256
410
- BLOCK_DPE = 32
411
- elif Lq == 192:
412
- BLOCK_DMODEL = 128
413
- BLOCK_DPE = 64
414
- else:
415
- BLOCK_DMODEL = triton.next_power_of_2(Lq)
416
- BLOCK_DPE = 0
417
- BLOCK_DV = triton.next_power_of_2(Lv)
418
-
419
- if _is_hip:
420
- BLOCK_M, BLOCK_N = (64, 64)
421
- num_warps = 4
422
-
423
- else:
424
- if _is_cuda and CUDA_CAPABILITY[0] >= 9:
425
- if Lq <= 256:
426
- BLOCK_M, BLOCK_N = (128, 64)
427
- else:
428
- BLOCK_M, BLOCK_N = (32, 64)
429
- elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
430
- # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
431
- if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
432
- if Lq <= 128:
433
- BLOCK_M, BLOCK_N = (64, 128)
434
- elif Lq <= 256:
435
- BLOCK_M, BLOCK_N = (64, 64)
436
- else:
437
- BLOCK_M, BLOCK_N = (32, 32)
438
- else:
439
- if Lq <= 128:
440
- BLOCK_M, BLOCK_N = (128, 128)
441
- elif Lq <= 256:
442
- BLOCK_M, BLOCK_N = (64, 64)
443
- else:
444
- BLOCK_M, BLOCK_N = (32, 64)
445
- else:
446
- BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
447
-
448
- num_warps = 4 if Lk <= 64 else 8
575
+ # Get block sizes and configuration
576
+ BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
577
+ _get_block_sizes_for_extend_attention(Lq, Lv)
578
+ )
449
579
 
450
580
  sm_scale = sm_scale or 1.0 / (Lq**0.5)
451
581
  batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
@@ -548,3 +678,368 @@ def redundant_attention(
548
678
  pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
549
679
  o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
550
680
  pt += cur_seq_len_extend
681
+
682
+
683
@triton.jit
def _fwd_kernel_unified(
    Q,
    O,
    K_Buffer,
    V_Buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    prefix_lens,
    mask_ptr,
    mask_indptr,
    sink_ptr,
    window_start_pos,
    sm_scale,
    kv_group_num,
    stride_qbs,
    stride_qh,
    stride_obs,
    stride_oh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    SLIDING_WINDOW_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    xai_temperature_len: tl.constexpr,
    Lq: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    USE_CUSTOM_MASK: tl.constexpr,
    HAS_SINK: tl.constexpr,
):
    """
    Single-pass extend attention over a unified (prefix + extend) KV list.

    Grid axes: program_id(0) selects the sequence, program_id(1) the query
    head, program_id(2) the block of BLOCK_M query rows. For every block of
    BLOCK_N keys the kernel gathers K/V rows from the buffers through
    ``kv_indices``, applies the bounds / custom / causal / sliding-window
    masks, and accumulates an online softmax. The normalized result is
    written to ``O``.
    """
    seq_id = tl.program_id(0)
    head_id = tl.program_id(1)
    q_block_id = tl.program_id(2)
    kv_head_id = head_id // kv_group_num

    # Per-sequence extents taken from the indptr arrays.
    q_begin = tl.load(qo_indptr + seq_id)
    q_len = tl.load(qo_indptr + seq_id + 1) - q_begin
    kv_begin = tl.load(kv_indptr + seq_id)
    kv_len = tl.load(kv_indptr + seq_id + 1) - kv_begin
    prefix_len = tl.load(prefix_lens + seq_id)

    # Absolute position of the first key in the window; stays 0 when the
    # sliding window is disabled.
    window_base = 0
    if SLIDING_WINDOW_SIZE > 0:
        window_base = tl.load(window_start_pos + seq_id)

    # Start of this sequence's slice of the flattened custom mask.
    if USE_CUSTOM_MASK:
        mask_begin = tl.load(mask_indptr + seq_id)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)

    # Row index of each query within this sequence's extend segment.
    q_rows = q_block_id * BLOCK_M + offs_m
    mask_m = q_rows < q_len
    mask_d = offs_d < Lq
    mask_dv = offs_dv < Lv

    # Per-row temperature factor: 1.0 below the threshold position,
    # xai_temperature_len / (pos + 1) at or above it.
    if xai_temperature_len > 0:
        q_pos = prefix_len + q_rows
        xai_temperature_reg = tl.where(
            q_pos < xai_temperature_len,
            1.0,
            xai_temperature_len / (q_pos + 1.0),
        )

    # Query tile (and the extra positional-embedding slice when BLOCK_DPE > 0).
    q_row_base = (q_begin + q_rows[:, None]) * stride_qbs + head_id * stride_qh
    offs_q = q_row_base + offs_d[None, :]
    q = tl.load(Q + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        offs_qpe = q_row_base + offs_dpe[None, :]
        qpe = tl.load(Q + offs_qpe, mask=mask_m[:, None], other=0.0)

    # Online-softmax state: weighted value sum, denominator, running max.
    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

    # One loop over every key of the sequence, prefix first, then extend.
    for start_n in range(0, kv_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        kv_cols = start_n + offs_n
        mask_n = kv_cols < kv_len

        # Start from the in-bounds mask and narrow it down.
        final_mask = mask_m[:, None] & mask_n[None, :]

        if USE_CUSTOM_MASK:
            # Mask rows are laid out with a row stride of kv_len.
            custom_mask = tl.load(
                mask_ptr
                + mask_begin
                + q_rows[:, None] * kv_len
                + start_n
                + offs_n[None, :],
                mask=final_mask,
                other=0,
            )
            final_mask &= custom_mask

        if IS_CAUSAL and not USE_CUSTOM_MASK:
            # Keys below prefix_len are always visible; keys in the extend
            # part are visible only to queries at the same or a later row.
            k_pos = kv_cols[None, :]
            causal_mask = tl.where(
                k_pos >= prefix_len,
                q_rows[:, None] >= (k_pos - prefix_len),
                True,
            )
            final_mask &= causal_mask

        if SLIDING_WINDOW_SIZE > 0:
            # Compare absolute positions: queries sit after the prefix,
            # keys are indexed directly in the unified list.
            q_abs_pos = window_base + prefix_len + q_rows[:, None]
            k_abs_pos = window_base + kv_cols[None, :]
            final_mask &= q_abs_pos <= (k_abs_pos + SLIDING_WINDOW_SIZE)

        # A tile with no visible (query, key) pair contributes nothing.
        SKIP_TILE = False
        if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

        if not SKIP_TILE:
            # Resolve the buffer slots of this block of keys.
            kv_loc = tl.load(
                kv_indices + kv_begin + start_n + offs_n,
                mask=mask_n,
                other=0,
            )

            # K is loaded transposed: [BLOCK_DMODEL, BLOCK_N].
            k_col_base = kv_loc[None, :] * stride_buf_kbs + kv_head_id * stride_buf_kh
            offs_buf_k = k_col_base + offs_d[:, None]
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(mask_n[None, :]) & (mask_d[:, None]),
                other=0.0,
            )

            qk = tl.dot(q.to(k.dtype), k)
            if BLOCK_DPE > 0:
                offs_kpe = k_col_base + offs_dpe[:, None]
                kpe = tl.load(
                    K_Buffer + offs_kpe,
                    mask=mask_n[None, :],
                    other=0.0,
                )
                qk += tl.dot(qpe.to(kpe.dtype), kpe)

            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            if xai_temperature_len > 0:
                qk *= xai_temperature_reg[:, None]

            qk = tl.where(final_mask, qk, float("-inf"))

            # Online softmax update. A fully masked row would give a -inf
            # max; it is replaced by a large negative finite value so the
            # exponentials below stay well defined.
            tile_max = tl.max(qk, 1)
            tile_max_safe = tl.where(tile_max == float("-inf"), -1e20, tile_max)
            new_max = tl.maximum(tile_max_safe, e_max)

            re_scale = tl.exp(e_max - new_max)
            p = tl.exp(qk - new_max[:, None])
            deno = deno * re_scale + tl.sum(p, 1)

            # Gather V rows and fold them into the accumulator.
            offs_buf_v = (
                kv_loc[:, None] * stride_buf_vbs
                + kv_head_id * stride_buf_vh
                + offs_dv[None, :]
            )
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=mask_n[:, None] & mask_dv[None, :],
                other=0.0,
            )
            p = p.to(v.dtype)
            acc = acc * re_scale[:, None] + tl.dot(p, v)

            e_max = new_max

    # The per-head sink logit only enlarges the softmax denominator.
    if HAS_SINK:
        cur_sink = tl.load(sink_ptr + head_id)
        deno += tl.exp(cur_sink - e_max)

    # Normalize and write the output tile.
    offs_o = (
        (q_begin + q_rows[:, None]) * stride_obs
        + head_id * stride_oh
        + offs_dv[None, :]
    )
    tl.store(
        O + offs_o,
        acc / deno[:, None],
        mask=mask_m[:, None] & mask_dv[None, :],
    )
933
+
934
+
935
def extend_attention_fwd_unified(
    q,
    o,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    prefix_lens,
    max_len_extend,
    custom_mask=None,
    mask_indptr=None,
    sm_scale=None,
    logit_cap=0.0,
    is_causal=True,
    sliding_window_size=-1,
    sinks=None,
    window_start_pos=None,
    xai_temperature_len=-1,
):
    """
    Launch the unified 1-stage extend-attention kernel (deterministic path).

    Prefix and extend keys/values are both reached through ``kv_indices``,
    so a single kernel pass covers the whole context of every sequence.

    Args:
        q: Query tensor [num_tokens, num_heads, head_dim].
        o: Output tensor [num_tokens, num_heads, head_dim]; written in place.
        k_buffer: Key cache buffer.
        v_buffer: Value cache buffer.
        qo_indptr: Query offsets [batch_size + 1].
        kv_indptr: KV offsets [batch_size + 1], covering prefix and extend.
        kv_indices: Unified KV indices (prefix followed by extend).
        prefix_lens: Prefix length of each sequence [batch_size].
        max_len_extend: Maximum extend length; sizes the query-block grid axis.
        custom_mask: Optional custom attention mask (speculative decoding
            tree attention). When given, the kernel's causal mask is not applied.
        mask_indptr: Mask offsets [batch_size + 1].
        sm_scale: Softmax scale. Any falsy value (``None`` or ``0``) selects
            the default ``1 / sqrt(head_dim)``.
        logit_cap: Logit capping value; capping is applied only when > 0.
        is_causal: Whether to apply the causal mask.
        sliding_window_size: Sliding window size (<= 0 disables it).
        sinks: Optional per-head sink logits.
        window_start_pos: Absolute position of the first key in each
            sequence's window [batch_size]. If omitted while a sliding
            window is active, zeros are used.
        xai_temperature_len: XAI temperature length (<= 0 disables it).
    """
    qk_head_dim = q.shape[-1]
    v_head_dim = v_buffer.shape[-1]

    # Tile sizes and warp count for these head dimensions.
    BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
        _get_block_sizes_for_extend_attention(qk_head_dim, v_head_dim)
    )

    if not sm_scale:
        sm_scale = 1.0 / (qk_head_dim**0.5)

    num_seqs = qo_indptr.shape[0] - 1
    num_q_heads = q.shape[1]
    kv_group_num = num_q_heads // k_buffer.shape[1]

    # With a sliding window but no explicit start positions, treat every
    # window as starting at absolute position 0.
    if window_start_pos is None and sliding_window_size > 0:
        window_start_pos = torch.zeros(num_seqs, dtype=torch.int32, device=q.device)

    # Launch options; the extra entries are only added on HIP builds.
    launch_opts = {"num_warps": num_warps, "num_stages": 1}
    if _is_hip:
        launch_opts.update({"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2})

    # One program per (sequence, query head, block of BLOCK_M query rows).
    grid = (num_seqs, num_q_heads, triton.cdiv(max_len_extend, BLOCK_M))

    _fwd_kernel_unified[grid](
        q,
        o,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        prefix_lens,
        custom_mask,
        mask_indptr,
        sinks,
        window_start_pos,
        sm_scale,
        kv_group_num,
        q.stride(0),
        q.stride(1),
        o.stride(0),
        o.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        v_buffer.stride(0),
        v_buffer.stride(1),
        SLIDING_WINDOW_SIZE=sliding_window_size,
        logit_cap=logit_cap,
        xai_temperature_len=xai_temperature_len,
        Lq=qk_head_dim,
        Lv=v_head_dim,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        IS_CAUSAL=is_causal,
        USE_CUSTOM_MASK=custom_mask is not None,
        HAS_SINK=sinks is not None,
        **launch_opts,
    )
@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
20
20
  if is_flashinfer_available():
21
21
  import flashinfer
22
22
 
23
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
24
-
25
23
  if TYPE_CHECKING:
26
24
  from sglang.srt.layers.radix_attention import RadixAttention
27
25
  from sglang.srt.model_executor.model_runner import ModelRunner
28
- from sglang.srt.speculative.spec_info import SpecInfo
26
+ from sglang.srt.speculative.spec_info import SpecInput
29
27
 
30
28
  # Constants
31
29
  DEFAULT_WORKSPACE_SIZE_MB = (
@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
201
199
  seq_lens: torch.Tensor,
202
200
  encoder_lens: Optional[torch.Tensor],
203
201
  forward_mode: ForwardMode,
204
- spec_info: Optional[SpecInfo],
202
+ spec_info: Optional[SpecInput],
205
203
  ):
206
204
  """Initialize metadata for CUDA graph capture."""
207
205
  metadata = TRTLLMMHAMetadata()
@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
314
312
  seq_lens_sum: int,
315
313
  encoder_lens: Optional[torch.Tensor],
316
314
  forward_mode: ForwardMode,
317
- spec_info: Optional[SpecInfo],
315
+ spec_info: Optional[SpecInput],
318
316
  seq_lens_cpu: Optional[torch.Tensor],
319
317
  ):
320
318
  """Replay CUDA graph with new inputs."""
@@ -639,7 +637,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
639
637
  self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
640
638
  ):
641
639
  super().__init__(model_runner, topk, speculative_num_steps)
642
- for i in range(speculative_num_steps):
640
+ for i in range(self.speculative_num_steps - 1):
643
641
  self.attn_backends[i] = TRTLLMHAAttnBackend(
644
642
  model_runner,
645
643
  skip_prefill=True,
@@ -653,7 +651,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
653
651
  self.attn_backends[i].init_forward_metadata(forward_batch)
654
652
 
655
653
  def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
656
- for i in range(self.speculative_num_steps):
654
+ for i in range(self.speculative_num_steps - 1):
657
655
  self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
658
656
 
659
657
  def init_forward_metadata_capture_cuda_graph(
@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
661
659
  forward_batch: ForwardBatch,
662
660
  ):
663
661
  assert forward_batch.spec_info is not None
664
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
662
+ assert forward_batch.spec_info.is_draft_input()
665
663
 
666
664
  for i in range(self.speculative_num_steps - 1):
667
665
  self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
678
676
  self, forward_batch: ForwardBatch, bs: int
679
677
  ):
680
678
  assert forward_batch.spec_info is not None
681
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
679
+ assert forward_batch.spec_info.is_draft_input()
682
680
 
683
681
  for i in range(self.speculative_num_steps - 1):
684
682