sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_mla_backend.py (new file)
@@ -0,0 +1,582 @@
+ from __future__ import annotations
+
+ """
+ Support attention backend for flashinfer MLA.
+ The flashinfer_mla_disable_ragged flag controls whether to use ragged prefill wrapper and defaults to be false.
+ When it's set to false, all wrappers are BatchMLAPaged wrapper.
+ When it's set to true, the backend uses BatchRagged and BatchMLAPaged wrapper for prefilling,
+ and uses BatchMLAPaged wrapper for decoding.
+ More details can be found in https://docs.flashinfer.ai/api/mla.html
+ """
+
+ from dataclasses import dataclass
+ from functools import partial
+ from typing import TYPE_CHECKING, Optional, Union
+
+ import torch
+
+ from sglang.global_config import global_config
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+ from sglang.srt.layers.attention.flashinfer_backend import (
+     create_flashinfer_kv_indices_triton,
+ )
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.utils import is_flashinfer_available
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.radix_attention import RadixAttention
+     from sglang.srt.model_executor.model_runner import ModelRunner
+     from sglang.srt.speculative.spec_info import SpecInfo
+
+ if is_flashinfer_available():
+     from flashinfer import (
+         BatchMLAPagedAttentionWrapper,
+         BatchPrefillWithRaggedKVCacheWrapper,
+     )
+
+
+ @dataclass
+ class DecodeMetadata:
+     decode_wrapper: BatchMLAPagedAttentionWrapper
+
+
+ @dataclass
+ class PrefillMetadata:
+     prefill_wrapper: BatchMLAPagedAttentionWrapper
+     use_ragged: bool
+
+
+ # Reuse this workspace buffer across all flashinfer wrappers
+ global_workspace_buffer = None
+
+
+ class FlashInferMLAAttnBackend(AttentionBackend):
+     """Flashinfer attention kernels."""
+
+     def __init__(
+         self,
+         model_runner: ModelRunner,
+     ):
+         super().__init__()
+
+         # Parse constants
+         self.max_context_len = model_runner.model_config.context_len
+         self.device = model_runner.device
+
+         global_config.enable_flashinfer_mla = True
+
+         # Allocate buffers
+         global global_workspace_buffer
+         if global_workspace_buffer is None:
+             global_workspace_buffer = torch.empty(
+                 global_config.flashinfer_workspace_size,
+                 dtype=torch.uint8,
+                 device=model_runner.device,
+             )
+         self.workspace_buffer = global_workspace_buffer
+
+         max_bs = model_runner.req_to_token_pool.size
+         self.kv_indptr = torch.zeros(
+             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+         )
+
+         self.qo_indptr = torch.zeros(
+             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+         )
+
+         self.q_indptr_decode = torch.arange(
+             0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+         )
+
+         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+             self.workspace_buffer, "NHD"
+         )
+
+         self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
+             self.workspace_buffer,
+             backend="auto",
+         )
+
+         self.decode_wrapper = BatchMLAPagedAttentionWrapper(
+             self.workspace_buffer, backend="auto"
+         )
+
+         # Create indices updater
+         self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
+             model_runner, self
+         )
+         self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
+             model_runner, self
+         )
+
+         # Other metadata
+         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+         self.decode_cuda_graph_metadata = {}
+         self.prefill_cuda_graph_metadata = {}
+
+     def init_forward_metadata(self, forward_batch: ForwardBatch):
+         if forward_batch.forward_mode.is_decode_or_idle():
+             self.indices_updater_decode.update(
+                 forward_batch.req_pool_indices,
+                 forward_batch.seq_lens,
+                 forward_batch.seq_lens_sum,
+                 decode_wrapper=self.decode_wrapper,
+                 init_metadata_replay=False,
+             )
+             self.forward_metadata = DecodeMetadata(self.decode_wrapper)
+         else:
+             prefix_lens = forward_batch.extend_prefix_lens
+             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+             use_ragged = (
+                 not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                 and extend_no_prefix
+             )
+
+             self.indices_updater_prefill.update(
+                 forward_batch.req_pool_indices,
+                 forward_batch.seq_lens,
+                 forward_batch.seq_lens_sum,
+                 prefix_lens,
+                 prefill_wrapper_paged=self.prefill_wrapper_paged,
+                 use_ragged=use_ragged,
+             )
+             self.forward_metadata = PrefillMetadata(
+                 self.prefill_wrapper_paged, use_ragged
+             )
+
+     def init_cuda_graph_state(
+         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+     ):
+         if kv_indices_buf is None:
+             cuda_graph_kv_indices = torch.zeros(
+                 (max_bs * self.max_context_len,),
+                 dtype=torch.int32,
+                 device="cuda",
+             )
+         else:
+             cuda_graph_kv_indices = kv_indices_buf
+
+         self.cuda_graph_kv_indices = cuda_graph_kv_indices
+         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
+         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
+         self.cuda_graph_kv_lens = torch.ones(
+             (max_bs,), dtype=torch.int32, device=self.device
+         )
+
+         # For fast decode plan in graph replaying
+         self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
+         self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
+         self.fast_decode_kwargs = {
+             "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
+             "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
+             "kv_indices": self.cuda_graph_kv_indices,
+         }
+
+     def init_forward_metadata_capture_cuda_graph(
+         self,
+         bs: int,
+         num_tokens: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         encoder_lens: Optional[torch.Tensor],
+         forward_mode: ForwardMode,
+         spec_info: Optional[SpecInfo],
+     ):
+         if forward_mode.is_decode_or_idle():
+             decode_wrapper = BatchMLAPagedAttentionWrapper(
+                 self.workspace_buffer,
+                 use_cuda_graph=True,
+                 qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
+                 kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
+                 kv_indices=self.cuda_graph_kv_indices,
+                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
+                 backend="auto",
+             )
+
+             seq_lens_sum = seq_lens.sum().item()
+             self.indices_updater_decode.update(
+                 req_pool_indices,
+                 seq_lens,
+                 seq_lens_sum,
+                 decode_wrapper=decode_wrapper,
+                 init_metadata_replay=False,
+             )
+             self.decode_cuda_graph_metadata[bs] = decode_wrapper
+             self.forward_metadata = DecodeMetadata(decode_wrapper)
+             decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
+         else:
+             raise ValueError(f"Invalid mode: {forward_mode=}")
+
+     def init_forward_metadata_replay_cuda_graph(
+         self,
+         bs: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         seq_lens_sum: int,
+         encoder_lens: Optional[torch.Tensor],
+         forward_mode: ForwardMode,
+         spec_info: Optional[SpecInfo],
+         seq_lens_cpu: Optional[torch.Tensor],
+     ):
+         if forward_mode.is_decode_or_idle():
+             kv_len_arr_cpu = seq_lens_cpu[:bs]
+             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
+                 kv_len_arr_cpu, dim=0
+             )
+             self.fast_decode_kwargs.update(
+                 {
+                     "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
+                     "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
+                     "kv_len_arr_cpu": kv_len_arr_cpu,
+                 }
+             )
+
+             self.indices_updater_decode.update(
+                 req_pool_indices[:bs],
+                 seq_lens[:bs],
+                 seq_lens_sum,
+                 decode_wrapper=self.decode_cuda_graph_metadata[bs],
+                 init_metadata_replay=True,
+                 **self.fast_decode_kwargs,
+             )
+         else:
+             raise ValueError(f"Invalid forward mode: {forward_mode=}")
+
+     def get_cuda_graph_seq_len_fill_value(self):
+         return 0
+
+     def forward_extend(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache=True,
+     ):
+
+         cache_loc = forward_batch.out_cache_loc
+         logits_soft_cap = layer.logit_cap
+         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
+         qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+         # Save kv cache
+         if save_kv_cache and k is not None:
+             assert v is not None
+             if save_kv_cache:
+                 forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+         if self.forward_metadata.use_ragged:
+             # ragged prefill
+             o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+                 qall,
+                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                 causal=True,
+                 sm_scale=layer.scaling,
+                 logits_soft_cap=logits_soft_cap,
+             )
+         else:
+             # mla paged prefill
+             o = prefill_wrapper_paged.run(
+                 qall[:, :, : layer.v_head_dim],
+                 qall[:, :, layer.v_head_dim :],
+                 k_buf[:, :, : layer.v_head_dim],
+                 k_buf[:, :, layer.v_head_dim :],
+             )
+
+         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+     def forward_decode(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache=True,
+     ):
+         decode_wrapper = self.forward_metadata.decode_wrapper
+         cache_loc = forward_batch.out_cache_loc
+
+         if k is not None:
+             assert v is not None
+             if save_kv_cache:
+                 forward_batch.token_to_kv_pool.set_kv_buffer(
+                     layer,
+                     cache_loc,
+                     k,
+                     v,
+                 )
+         reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+         reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+         o = decode_wrapper.run(
+             reshaped_q[:, :, : layer.v_head_dim],
+             reshaped_q[:, :, layer.v_head_dim :],
+             reshaped_k[:, :, : layer.v_head_dim],
+             reshaped_k[:, :, layer.v_head_dim :],
+         )
+
+         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+
+ class FlashInferMLAIndicesUpdaterDecode:
+     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+         # Parse Constants
+         self.num_local_heads = (
+             model_runner.model_config.num_attention_heads // get_attention_tp_size()
+         )
+         self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+         self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+         self.scaling = model_runner.model_config.scaling
+         self.data_type = model_runner.kv_cache_dtype
+         self.attn_backend = attn_backend
+
+         # Buffers and wrappers
+         self.kv_indptr = attn_backend.kv_indptr
+         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+         self.q_indptr = attn_backend.q_indptr_decode
+
+     def update(
+         self,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         seq_lens_sum: int,
+         decode_wrapper: BatchMLAPagedAttentionWrapper,
+         init_metadata_replay: bool = False,
+         **fast_decode_kwargs,
+     ):
+         decode_wrapper = decode_wrapper or self.decode_wrapper
+         self.call_begin_forward(
+             decode_wrapper,
+             req_pool_indices,
+             seq_lens,
+             seq_lens_sum,
+             self.q_indptr,
+             self.kv_indptr,
+             init_metadata_replay,
+             **fast_decode_kwargs,
+         )
+
+     def call_begin_forward(
+         self,
+         wrapper: BatchMLAPagedAttentionWrapper,
+         req_pool_indices: torch.Tensor,
+         paged_kernel_lens: torch.Tensor,
+         paged_kernel_lens_sum: int,
+         q_indptr: torch.Tensor,
+         kv_indptr: torch.Tensor,
+         init_metadata_replay: bool = False,
+         **fast_decode_kwargs,
+     ):
+         bs = len(req_pool_indices)
+         q_indptr = q_indptr[: bs + 1]
+         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+         kv_indptr = kv_indptr[: bs + 1]
+         kv_indices = (
+             torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+             if not init_metadata_replay
+             else fast_decode_kwargs["kv_indices"]
+         )
+
+         kv_lens = paged_kernel_lens.to(torch.int32)
+         sm_scale = self.scaling
+
+         create_flashinfer_kv_indices_triton[(bs,)](
+             self.req_to_token,
+             req_pool_indices,
+             paged_kernel_lens,
+             kv_indptr,
+             None,
+             kv_indices,
+             self.req_to_token.shape[1],
+         )
+         if not init_metadata_replay:
+             wrapper.plan(
+                 q_indptr,
+                 kv_indptr,
+                 kv_indices,
+                 kv_lens,
+                 self.num_local_heads,
+                 self.kv_lora_rank,
+                 self.qk_rope_head_dim,
+                 1,
+                 False,
+                 sm_scale,
+                 self.data_type,
+                 self.data_type,
+             )
+         else:
+             wrapper.plan(
+                 fast_decode_kwargs["qo_indptr_cpu"],
+                 fast_decode_kwargs["kv_indptr_cpu"],
+                 kv_indices,
+                 fast_decode_kwargs["kv_len_arr_cpu"],
+                 self.num_local_heads,
+                 self.kv_lora_rank,
+                 self.qk_rope_head_dim,
+                 1,
+                 False,
+                 sm_scale,
+                 self.data_type,
+                 self.data_type,
+             )
+
+
+ class FlashInferMLAIndicesUpdaterPrefill:
+     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+         # Parse Constants
+         self.num_local_heads = (
+             model_runner.model_config.num_attention_heads // get_attention_tp_size()
+         )
+         self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+         self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+         self.v_head_dim = model_runner.model_config.v_head_dim
+         self.scaling = model_runner.model_config.scaling
+         self.data_type = model_runner.kv_cache_dtype
+         self.q_data_type = model_runner.dtype
+         self.attn_backend = attn_backend
+
+         # Buffers and wrappers
+         self.kv_indptr = attn_backend.kv_indptr
+         self.qo_indptr = attn_backend.qo_indptr
+         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
+
+     def update(
+         self,
+         req_pool_indices: torch.Tnesor,
+         seq_lens: torch.Tensor,
+         seq_lens_sum: int,
+         prefix_lens: torch.Tensor,
+         prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
+         use_ragged: bool,
+     ):
+         if use_ragged:
+             paged_kernel_lens = prefix_lens
+             paged_kernel_lens_sum = paged_kernel_lens.sum().item()
+         else:
+             paged_kernel_lens = seq_lens
+             paged_kernel_lens_sum = seq_lens_sum
+
+         self.call_begin_forward(
+             self.prefill_wrapper_ragged,
+             prefill_wrapper_paged,
+             req_pool_indices,
+             paged_kernel_lens,
+             paged_kernel_lens_sum,
+             seq_lens,
+             prefix_lens,
+             self.kv_indptr,
+             self.qo_indptr,
+             use_ragged,
+         )
+
+     def call_begin_forward(
+         self,
+         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
+         wrapper_paged: BatchMLAPagedAttentionWrapper,
+         req_pool_indices: torch.Tensor,
+         paged_kernel_lens: torch.Tensor,
+         paged_kernel_lens_sum: int,
+         seq_lens: torch.Tensor,
+         prefix_lens: torch.Tensor,
+         kv_indptr: torch.Tensor,
+         qo_indptr: torch.Tensor,
+         use_ragged: bool,
+     ):
+         bs = len(req_pool_indices)
+         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+         kv_indptr = kv_indptr[: bs + 1]
+         kv_indices = torch.empty(
+             paged_kernel_lens_sum,
+             dtype=torch.int32,
+             device=req_pool_indices.device,
+         )
+         create_flashinfer_kv_indices_triton[(bs,)](
+             self.req_to_token,
+             req_pool_indices,
+             paged_kernel_lens,
+             kv_indptr,
+             None,
+             kv_indices,
+             self.req_to_token.shape[1],
+         )
+
+         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+         qo_indptr = qo_indptr[: bs + 1]
+         sm_scale = self.scaling
+
+         if use_ragged:
+             # ragged prefill
+             wrapper_ragged.begin_forward(
+                 qo_indptr=qo_indptr,
+                 kv_indptr=qo_indptr,
+                 num_qo_heads=self.num_local_heads,
+                 num_kv_heads=self.num_local_heads,
+                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                 head_dim_vo=self.v_head_dim,
+                 q_data_type=self.q_data_type,
+             )
+         else:
+             # mla paged prefill
+             kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+             wrapper_paged.plan(
+                 qo_indptr,
+                 kv_indptr,
+                 kv_indices,
+                 kv_len_arr,
+                 self.num_local_heads,
+                 self.kv_lora_rank,
+                 self.qk_rope_head_dim,
+                 1,
+                 True,
+                 sm_scale,
+                 self.q_data_type,
+                 self.data_type,
+             )
+
+
+ def fast_mla_decode_plan(
+     self,
+     qo_indptr_cpu: torch.Tensor,
+     kv_indptr_cpu: torch.Tensor,
+     kv_indices: torch.Tensor,
+     kv_len_arr_cpu: torch.Tensor,
+     num_heads: int,
+     head_dim_ckv: int,
+     head_dim_kpe: int,
+     page_size: int,
+     causal: bool,
+     sm_scale: float,
+     q_data_type: torch.dtype,
+     kv_data_type: torch.dtype,
+ ) -> None:
+     """A faster version of BatchMLAPagedAttentionWrapper::plan,
+     for skipping the stream synchronization in original plan function during
+     cuda graph replaying.
+     """
+     self._causal = causal
+     self._page_size = page_size
+     self._sm_scale = sm_scale
+
+     with self.device as device:
+         stream = torch.cuda.current_stream(device).cuda_stream
+         self._cached_module.plan(
+             self._float_workspace_buffer,
+             self._int_workspace_buffer,
+             self._pin_memory_int_workspace_buffer,
+             qo_indptr_cpu,
+             kv_indptr_cpu,
+             kv_len_arr_cpu,
+             num_heads,
+             head_dim_ckv,
+             causal,
+             stream,
+         )
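The new file's init_forward_metadata chooses between the ragged prefill wrapper and the paged MLA wrapper based on the flashinfer_mla_disable_ragged server argument and whether any request in the batch already has a cached prefix. As a standalone sketch of that selection rule (the helper below is illustrative only; it is not part of sglang or of this diff):

def choose_prefill_path(disable_ragged: bool, extend_prefix_lens_cpu: list) -> str:
    # Mirrors the use_ragged computation in FlashInferMLAAttnBackend.init_forward_metadata:
    # ragged prefill is used only when it is not disabled and no request in the batch
    # extends a prefix that is already cached in the paged KV pool.
    extend_no_prefix = not any(extend_prefix_lens_cpu)
    if not disable_ragged and extend_no_prefix:
        return "ragged"  # BatchPrefillWithRaggedKVCacheWrapper
    return "paged"  # BatchMLAPagedAttentionWrapper

assert choose_prefill_path(False, [0, 0]) == "ragged"
assert choose_prefill_path(False, [0, 5]) == "paged"
assert choose_prefill_path(True, [0, 0]) == "paged"

Decode batches always go through the BatchMLAPagedAttentionWrapper regardless of this flag.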
sglang/srt/layers/attention/torch_native_backend.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
  import torch
  from torch.nn.functional import scaled_dot_product_attention
 
- from sglang.srt.layers.attention import AttentionBackend
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
  if TYPE_CHECKING:
sglang/srt/layers/attention/triton_backend.py
@@ -1,11 +1,11 @@
  from __future__ import annotations
 
- from typing import TYPE_CHECKING, Optional
+ from typing import TYPE_CHECKING, Optional, Union
 
  import torch
  import triton
 
- from sglang.srt.layers.attention import AttentionBackend
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.attention.flashinfer_backend import (
      create_flashinfer_kv_indices_triton,
  )
@@ -15,7 +15,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
  if TYPE_CHECKING:
      from sglang.srt.layers.radix_attention import RadixAttention
      from sglang.srt.model_executor.model_runner import ModelRunner
-     from sglang.srt.speculative.spec_info import SpecInfo
+     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
  class TritonAttnBackend(AttentionBackend):
@@ -156,6 +156,7 @@ class TritonAttnBackend(AttentionBackend):
              spec_info.generate_attn_arg_prefill(
                  forward_batch.req_pool_indices,
                  forward_batch.seq_lens,
+                 None,
                  self.req_to_token,
              )
          )
@@ -232,7 +233,7 @@ class TritonAttnBackend(AttentionBackend):
          seq_lens: torch.Tensor,
          encoder_lens: Optional[torch.Tensor],
          forward_mode: ForwardMode,
-         spec_info: Optional[SpecInfo],
+         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
      ):
          assert encoder_lens is None, "Not supported"
 
@@ -310,7 +311,8 @@ class TritonAttnBackend(AttentionBackend):
          seq_lens_sum: int,
          encoder_lens: Optional[torch.Tensor],
          forward_mode: ForwardMode,
-         spec_info: Optional[SpecInfo],
+         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+         seq_lens_cpu: Optional[torch.Tensor],
      ):
          # NOTE: encoder_lens expected to be zeros or None
          if forward_mode.is_decode_or_idle():
@@ -474,7 +476,7 @@ class TritonMultiStepDraftBackend:
          self.topk = topk
          self.speculative_num_steps = speculative_num_steps
          self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-         max_bs = model_runner.req_to_token_pool.size
+         max_bs = model_runner.req_to_token_pool.size * self.topk
          self.kv_indptr = torch.zeros(
              (
                  self.speculative_num_steps,
@@ -576,16 +578,19 @@ class TritonMultiStepDraftBackend:
 
          self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
-     def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+     def init_forward_metadata_replay_cuda_graph(
+         self, forward_batch: ForwardBatch, bs: int
+     ):
          def call_fn(i, forward_batch):
              self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                 forward_batch.batch_size,
+                 bs,
                  forward_batch.req_pool_indices,
                  forward_batch.seq_lens,
                  seq_lens_sum=-1,
                  encoder_lens=None,
                  forward_mode=ForwardMode.DECODE,
                  spec_info=forward_batch.spec_info,
+                 seq_lens_cpu=None,
              )
 
          self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -635,6 +635,9 @@ def decode_attention_fwd(
      logit_cap=0.0,
  ):
      assert num_kv_splits == attn_logits.shape[2]
+     assert q.shape[0] <= kv_indptr.shape[0] - 1
+     assert q.shape[0] <= attn_logits.shape[0]
+
      kv_group_num = q.shape[1] // v_buffer.shape[1]
 
      if kv_group_num == 1: