sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/flashattention_backend.py
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -64,6 +66,9 @@ class FlashAttentionMetadata:
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
 
+    # For sliding window attention topk>1 spec decoding
+    swa_spec_metadata: Optional[FlashAttentionMetadata] = None
+
 
 # Copied from:
 # https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
@@ -340,6 +345,13 @@ class FlashAttentionBackend(AttentionBackend):
             else None
         )
 
+        # For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
+        # We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.has_swa = (
+            self.sliding_window_size is not None and self.sliding_window_size > -1
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
@@ -556,6 +568,12 @@ class FlashAttentionBackend(AttentionBackend):
                     (1, 0),
                 )
                 self.forward_metadata_spec_decode_expand = metadata_expand
+
+                if self.has_swa:
+                    self._init_sliding_window_attn_spec_metadata(
+                        metadata, metadata_expand
+                    )
+
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -657,11 +675,10 @@ class FlashAttentionBackend(AttentionBackend):
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
-        window_size = (
-            (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
-            else (-1, -1)
+        is_swa = (
+            layer.sliding_window_size is not None and layer.sliding_window_size > -1
         )
+        window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -684,8 +701,13 @@ class FlashAttentionBackend(AttentionBackend):
         )
 
         # We do cascade attention for Target Verify with topk > 1
+        # We don't use cascade attention for Sliding Window Attention:
+        # - Different window sizes should be passed in for each q in the first stage of cascade attention, but FA3 interface doesn't support pass in a list of window sizes.
+        # - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.
         use_cascade_attn = (
-            forward_batch.forward_mode.is_target_verify() and self.topk > 1
+            forward_batch.forward_mode.is_target_verify()
+            and self.topk > 1
+            and not is_swa
         )
 
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
@@ -700,13 +722,18 @@ class FlashAttentionBackend(AttentionBackend):
             cu_seqlens_q = local_metadata.local_query_start_loc
             cache_seqlens = local_metadata.local_seqused_k
             max_seqlen_q = local_metadata.local_max_query_len
-            max_seqlen_k = local_metadata.local_max_seq_len
+        elif is_swa and metadata.swa_spec_metadata is not None:
+            swa_spec_metadata = metadata.swa_spec_metadata
+            page_table = swa_spec_metadata.page_table
+            cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
+            cache_seqlens = swa_spec_metadata.cache_seqlens_int32
+            max_seqlen_q = swa_spec_metadata.max_seq_len_q
+            cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
         else:
             page_table = metadata.page_table
             cu_seqlens_q = metadata.cu_seqlens_q
             cache_seqlens = metadata.cache_seqlens_int32
             max_seqlen_q = metadata.max_seq_len_q
-            max_seqlen_k = metadata.max_seq_len_k
             cu_seqlens_k = metadata.cu_seqlens_k
 
         # Use Flash Attention for prefill
@@ -1377,6 +1404,32 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }
 
+        if self.has_swa:
+            self.target_verify_metadata_topk_swa = {
+                "cache_seqlens": torch.zeros(
+                    max_bs * self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "page_table": torch.zeros(
+                    max_bs * self.speculative_num_draft_tokens,
+                    self.max_context_len,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
         self.encoder_metadata = {
             "encoder_page_table": torch.zeros(
                 max_bs,
@@ -1564,6 +1617,28 @@ class FlashAttentionBackend(AttentionBackend):
 
                 self.target_verify_metadata_topk_normal[bs] = metadata
                 self.target_verify_metadata_topk_expand[bs] = metadata_expand
+
+                if self.has_swa:
+                    metadata_swa = FlashAttentionMetadata()
+                    metadata_swa.cache_seqlens_int32 = (
+                        self.target_verify_metadata_topk_swa["cache_seqlens"][
+                            : bs * self.speculative_num_draft_tokens
+                        ]
+                    )
+                    metadata_swa.max_seq_len_q = 1
+                    metadata_swa.cu_seqlens_q = self.target_verify_metadata_topk_swa[
+                        "cu_seqlens_q"
+                    ][: bs * self.speculative_num_draft_tokens + 1]
+                    metadata_swa.cu_seqlens_k = self.target_verify_metadata_topk_swa[
+                        "cu_seqlens_k"
+                    ][: bs * self.speculative_num_draft_tokens + 1]
+
+                    metadata_swa.page_table = self.target_verify_metadata_topk_swa[
+                        "page_table"
+                    ][: bs * self.speculative_num_draft_tokens]
+                    self.target_verify_metadata_topk_swa[bs] = metadata_swa
+                    metadata.swa_spec_metadata = metadata_swa
+
         elif forward_mode.is_draft_extend():
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
@@ -1804,6 +1879,12 @@ class FlashAttentionBackend(AttentionBackend):
                     )
                 )
 
+                if self.has_swa:
+                    metadata_swa = self.target_verify_metadata_topk_swa[bs]
+                    self._init_sliding_window_attn_spec_metadata(
+                        metadata, metadata_expand, metadata_swa
+                    )
+
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
             metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -2039,6 +2120,159 @@ class FlashAttentionBackend(AttentionBackend):
         lam.local_max_query_len = int(seqlens_q_local_np.max())
         lam.local_max_seq_len = int(seqlens_k_local_np.max())
 
+    def _init_sliding_window_attn_spec_metadata(
+        self,
+        metadata: FlashAttentionMetadata,
+        metadata_expand: FlashAttentionMetadata,
+        metadata_swa: Optional[FlashAttentionMetadata] = None,
+    ):
+        # TODO: support page_size > 1 for swa spec
+        assert (
+            self.page_size == 1
+        ), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention"
+
+        cache_seqlens_int32 = (
+            metadata.cache_seqlens_int32.repeat_interleave(
+                self.speculative_num_draft_tokens
+            )
+            + metadata_expand.cache_seqlens_int32
+        )
+        cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
+        )
+        bs = cache_seqlens_int32.shape[0]
+        page_table = (
+            metadata.page_table.new_zeros(
+                (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
+            )
+            if metadata_swa is None
+            else metadata_swa.page_table
+        )
+
+        prepare_swa_spec_page_table_triton(
+            page_table,
+            metadata.page_table,
+            metadata_expand.page_table,
+            metadata.cache_seqlens_int32,
+            metadata_expand.cache_seqlens_int32,
+            self.speculative_num_draft_tokens,
+        )
+
+        if metadata_swa is None:
+            metadata_swa = FlashAttentionMetadata()
+            metadata_swa.max_seq_len_q = 1
+            metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
+            metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
+            metadata_swa.cu_seqlens_k = cu_seqlens_k
+            metadata_swa.page_table = page_table
+        else:
+            metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
+            metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)
+
+        metadata.swa_spec_metadata = metadata_swa
+
+
+@triton.jit
+def _prepare_swa_spec_page_table_kernel(
+    dst_ptr,
+    src_a_ptr,
+    src_b_ptr,
+    seq_len_a_ptr,
+    seq_len_b_ptr,
+    dst_stride_m,
+    dst_stride_n,
+    a_stride_m,
+    a_stride_n,
+    b_stride_m,
+    b_stride_n,
+    LEN_A: tl.constexpr,
+    LEN_B: tl.constexpr,
+    REPEAT_STEP: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    idx_a = pid_m // REPEAT_STEP
+    idx_b = pid_m
+    seq_len_a = tl.load(seq_len_a_ptr + idx_a)
+    seq_len_b = tl.load(seq_len_b_ptr + idx_b)
+
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    total_len = seq_len_a + seq_len_b
+
+    if pid_n * BLOCK_N >= total_len:
+        return
+
+    mask = offs_n < total_len
+    dst = dst_ptr + pid_m * dst_stride_m + offs_n * dst_stride_n
+
+    if (pid_n + 1) * BLOCK_N < seq_len_a:
+        a_ptr = src_a_ptr + idx_a * a_stride_m + offs_n * a_stride_n
+        a_mask = mask & (offs_n < LEN_A)
+        val = tl.load(a_ptr, mask=a_mask, other=0)
+        tl.store(dst, val, mask=mask)
+    elif pid_n * BLOCK_N >= seq_len_a:
+        offs_b = offs_n - seq_len_a
+        b_ptr = src_b_ptr + idx_b * b_stride_m + offs_b * b_stride_n
+        b_mask = mask & (offs_b < LEN_B)
+        val = tl.load(b_ptr, mask=b_mask, other=0)
+        tl.store(dst, val, mask=mask)
+    else:
+        # mixed part
+        a_offs = offs_n
+        a_mask = (a_offs < seq_len_a) & (a_offs < LEN_A)
+        a_ptr = src_a_ptr + idx_a * a_stride_m + a_offs * a_stride_n
+        a_val = tl.load(a_ptr, mask=a_mask, other=0)
+
+        b_offs = offs_n - seq_len_a
+        b_mask = (b_offs >= 0) & (b_offs < seq_len_b) & (b_offs < LEN_B)
+        b_ptr = src_b_ptr + idx_b * b_stride_m + b_offs * b_stride_n
+        b_val = tl.load(b_ptr, mask=b_mask, other=0)
+
+        result = tl.where(offs_n < seq_len_a, a_val, b_val)
+        tl.store(dst, result, mask=mask)
+
+
+def prepare_swa_spec_page_table_triton(
+    page_table_dst: torch.Tensor,
+    page_table_a: torch.Tensor,
+    page_table_b: torch.Tensor,  # expand page table
+    seq_len_a: torch.Tensor,
+    seq_len_b: torch.Tensor,  # expand seq lens
+    speculative_num_draft_tokens: int,
+):
+    # concat page_table and expand page_table by kv seq length
+    bs = seq_len_a.numel()
+    bs_expand = seq_len_b.numel()
+    assert bs_expand == bs * speculative_num_draft_tokens
+
+    LEN_A = page_table_a.shape[1]
+    LEN_B = page_table_b.shape[1]
+    LEN_OUT = LEN_A + LEN_B
+    REPEAT_STEP = speculative_num_draft_tokens
+    BLOCK_N = 256
+
+    grid = (bs_expand, triton.cdiv(LEN_OUT, BLOCK_N))
+    _prepare_swa_spec_page_table_kernel[grid](
+        page_table_dst,
+        page_table_a,
+        page_table_b,
+        seq_len_a,
+        seq_len_b,
+        page_table_dst.stride(0),
+        page_table_dst.stride(1),
+        page_table_a.stride(0),
+        page_table_a.stride(1),
+        page_table_b.stride(0),
+        page_table_b.stride(1),
+        LEN_A=LEN_A,
+        LEN_B=LEN_B,
+        REPEAT_STEP=REPEAT_STEP,
+        BLOCK_N=BLOCK_N,
+        num_warps=4,
+    )
+
 
 class FlashAttentionMultiStepBackend:
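
The Triton kernel added above fills one destination row per (request, draft token) pair: the first seq_len_a[i] entries of the row come from the request's prefix page table, and the following entries come from that draft token's expanded page table. An equivalent eager PyTorch loop is sketched below for reference; it is illustrative only (the function name and comments are not part of the package) and assumes the same tensor layout as the kernel's arguments.

import torch


def prepare_swa_spec_page_table_reference(
    page_table_dst: torch.Tensor,  # [bs * num_draft, LEN_A + LEN_B]
    page_table_a: torch.Tensor,  # [bs, LEN_A] prefix page table
    page_table_b: torch.Tensor,  # [bs * num_draft, LEN_B] expanded page table
    seq_len_a: torch.Tensor,  # [bs] prefix KV lengths
    seq_len_b: torch.Tensor,  # [bs * num_draft] expanded KV lengths
    num_draft: int,
) -> None:
    # Row r of the destination corresponds to request r // num_draft and draft
    # token r % num_draft: its KV indices are the request's prefix entries
    # followed by the expanded entries for that draft token.
    for r in range(page_table_b.shape[0]):
        i = r // num_draft
        la = int(seq_len_a[i])
        lb = int(seq_len_b[r])
        page_table_dst[r, :la] = page_table_a[i, :la]
        page_table_dst[r, la : la + lb] = page_table_b[r, :lb]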
 

sglang/srt/layers/attention/flashinfer_backend.py
@@ -26,11 +26,14 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available, next_power_of_2
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -498,8 +501,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )
         else:
             causal = True
@@ -577,8 +581,9 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
+            # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+            k_scale=layer.k_scale_float,
+            v_scale=layer.v_scale_float,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
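
The k_scale/v_scale change above matters for CUDA graph capture: layer.k_scale is a GPU tensor, and, as the new comment notes, turning it into a host-side value inside the wrapper call forces a device-to-host copy that breaks capture. Reading the scale once at weight-loading time and caching it as a Python float sidesteps that. A minimal sketch of the idea follows; the class and attribute wiring here are illustrative assumptions, not sglang's actual RadixAttention code.

from typing import Optional

import torch


class KVScaleCache:
    # Illustrative sketch: keep KV scales both as tensors and as plain floats.
    def __init__(self, k_scale: Optional[torch.Tensor], v_scale: Optional[torch.Tensor]):
        self.k_scale = k_scale
        self.v_scale = v_scale
        # .item() runs once here, outside CUDA graph capture, so later attention
        # calls can pass plain floats without triggering a device-to-host copy.
        self.k_scale_float = k_scale.item() if k_scale is not None else 1.0
        self.v_scale_float = v_scale.item() if v_scale is not None else 1.0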

sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -28,11 +28,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available, next_power_of_2
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -93,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
     def update_wrapper(
         self,
         forward_batch: ForwardBatch,
+        disable_flashinfer_ragged: bool = False,
     ):
         assert forward_batch.num_prefix_chunks is not None
         num_prefix_chunks = forward_batch.num_prefix_chunks
@@ -125,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
                 causal=False,
             )
         # ragged prefill
-        self.ragged_wrapper.begin_forward(
-            qo_indptr=qo_indptr,
-            kv_indptr=qo_indptr,
-            num_qo_heads=self.num_local_heads,
-            num_kv_heads=self.num_local_heads,
-            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-            head_dim_vo=self.v_head_dim,
-            q_data_type=self.q_data_type,
-            causal=True,
-        )
+        if not disable_flashinfer_ragged:
+            self.ragged_wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=True,
+            )
 
     def forward(
         self,
@@ -488,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
-    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+    def init_mha_chunk_metadata(
+        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
+    ):
         """Init the metadata for a forward pass."""
-        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)
 
     def forward_extend(
         self,

sglang/srt/layers/attention/hybrid_attn_backend.py
@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
@@ -12,19 +13,54 @@ class HybridAttnBackend(AttentionBackend):
     """Support different backends for prefill and decode."""
 
     def __init__(
-        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+        self,
+        model_runner: ModelRunner,
+        prefill_backend: AttentionBackend,
+        decode_backend: AttentionBackend,
     ):
+        self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend
 
-    def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
-            self.decode_backend.init_forward_metadata(forward_batch)
+    def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
+        """
+        Select the appropriate attention backend based on the forward mode.
+
+        Args:
+            forward_mode: The current forward mode indicating the operation type
+
+        Returns:
+            The selected attention backend (prefill or decode)
+
+        Note:
+            - decode_or_idle: Always uses decode backend
+            - target_verify or draft_extend: Uses decode backend if speculative_attention_mode is "decode", otherwise prefill backend
+            - prefill: Always uses prefill backend
+        """
+        if forward_mode.is_decode_or_idle():
+            return self.decode_backend
+        elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
+            return (
+                self.decode_backend
+                if self.model_runner.server_args.speculative_attention_mode == "decode"
+                else self.prefill_backend
+            )
         else:
-            self.prefill_backend.init_forward_metadata(forward_batch)
+            return self.prefill_backend
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        backend = self._select_backend(forward_batch.forward_mode)
+        backend.init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+        if (
+            self.model_runner.server_args.speculative_algorithm is not None
+            and self.model_runner.server_args.speculative_attention_mode == "prefill"
+        ):
+            # When speculative decoding is enabled, we need to initialize the backend
+            # that will be used for target_verify.
+            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -36,7 +72,8 @@ class HybridAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        self.decode_backend.init_forward_metadata_capture_cuda_graph(
+        backend = self._select_backend(forward_mode)
+        backend.init_forward_metadata_capture_cuda_graph(
             bs,
             num_tokens,
             req_pool_indices,
@@ -57,7 +94,8 @@ class HybridAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        self.decode_backend.init_forward_metadata_replay_cuda_graph(
+        backend = self._select_backend(forward_mode)
+        backend.init_forward_metadata_replay_cuda_graph(
             bs,
             req_pool_indices,
             seq_lens,
@@ -95,6 +133,7 @@ class HybridAttnBackend(AttentionBackend):
         save_kv_cache: bool = True,
         **kwargs,
     ):
-        return self.prefill_backend.forward_extend(
+        backend = self._select_backend(forward_batch.forward_mode)
+        return backend.forward_extend(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
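
As a summary of the routing introduced by _select_backend, the rule can be restated as a small stand-alone function; Mode and select_backend_name below are illustrative stand-ins for ForwardMode and the real method, not code from the package.

from enum import Enum, auto


class Mode(Enum):
    # Stand-in for sglang's ForwardMode in this illustration.
    DECODE_OR_IDLE = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()
    EXTEND = auto()


def select_backend_name(mode: Mode, speculative_attention_mode: str) -> str:
    # Mirrors the rule in HybridAttnBackend._select_backend above.
    if mode is Mode.DECODE_OR_IDLE:
        return "decode"
    if mode in (Mode.TARGET_VERIFY, Mode.DRAFT_EXTEND):
        return "decode" if speculative_attention_mode == "decode" else "prefill"
    return "prefill"


assert select_backend_name(Mode.TARGET_VERIFY, "prefill") == "prefill"
assert select_backend_name(Mode.TARGET_VERIFY, "decode") == "decode"
assert select_backend_name(Mode.EXTEND, "decode") == "prefill"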