sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
@@ -13,36 +13,64 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 
+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-from flash_attn_interface import flash_attn_with_kvcache
+from sgl_kernel.flash_attn import flash_attn_with_kvcache
 
 
 @dataclass
 class FlashAttentionMetadata:
-    """Metadata for decode operations to avoid redundant computations."""
+    """Metadata initialized once per model forward pass,
+    so that each layer's forward pass can reuse it."""
 
+    # Cumulative sequence lengths for query
     cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
     cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for query
+    max_seq_len_q: int = 0
+    # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Window size (typically used by Gemma)
    window_size: tuple = (-1, -1)
+    # Page table, the index of KV cache tables/blocks
     page_table: torch.Tensor = None
+    # Sequence lengths for the forward batch
     cache_seqlens_int32: torch.Tensor = None
-    max_seq_len_q: int = 0
 
 
 class FlashAttentionBackend(AttentionBackend):
-    """FlashAttention backend implementation."""
+    """FlashAttention backend implementation.
+
+    Note about initialization:
+    - Without speculative decoding:
+        - FlashAttentionBackend is initialized once when the server starts.
+    - With speculative decoding:
+        - FlashAttentionBackend is initialized once for the target worker.
+        - FlashAttentionMultiStepBackend is initialized once for the draft worker;
+          it spawns speculative_num_steps FlashAttentionBackend instances.
+
+    Note about CUDA Graph:
+    - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
+    - We don't support CUDA Graph for Extend and Draft Extend.
+    - At server init, init_cuda_graph_state is called first, then init_forward_metadata_capture_cuda_graph.
+    - For each forward batch, init_forward_metadata_replay_cuda_graph is called first, then the graph is replayed.
+    """
 
     def __init__(
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        topk=0,
+        speculative_num_steps=0,
+        step_id=0,
     ):
         super().__init__()
 
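Note (illustrative, not part of the diff): the cu_seqlens_* fields above follow the varlen FlashAttention convention, an exclusive prefix sum of per-request lengths with a leading zero, so request i owns the token range [cu_seqlens[i], cu_seqlens[i+1]). A minimal sketch of the pad(cumsum(...)) idiom used throughout this file, with invented lengths:

    import torch

    # Invented per-request KV lengths for a batch of three requests.
    seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

    # Exclusive prefix sum with a leading zero, as built repeatedly in this backend.
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
    )
    print(cu_seqlens_k.tolist())  # [0, 3, 8, 10]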
@@ -51,49 +79,138 @@ class FlashAttentionBackend(AttentionBackend):
             and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"
 
-        # Initialize metadata
         self.forward_metadata: FlashAttentionMetadata = None
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
+        self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.page_size = model_runner.page_size
+        self.use_mla = (
+            model_runner.model_config.attention_arch == AttentionArch.MLA
+        ) and (not global_server_args_dict["disable_mla"])
+        self.skip_prefill = skip_prefill
+
+        # TODO: Support topk > 1 for FlashAttentionBackend spec decoding
+        assert (
+            topk <= 1
+        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
+
+        self.topk = 1
+        self.step_id = step_id
+        self.speculative_num_steps = speculative_num_steps
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
-        # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()
-
-        extend_seq_lens = forward_batch.extend_seq_lens
-        # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
-        # Precompute int32 version of sequence lengths
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-        metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
-        )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seqlens_in_batch.max().item()
-        # Precompute page table
-        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
-            forward_batch.req_pool_indices, : metadata.max_seq_len_k
-        ]
-        if forward_batch.forward_mode == ForwardMode.DECODE:
-            # Precompute cumulative sequence lengths
+        if forward_batch.forward_mode.is_decode():
+            # Skip Prefill or Draft Decode
+            # Note: Draft Decode will be run on the Draft Worker
+            if forward_batch.spec_info is not None:
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
+                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.step_id + 1
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                cache_loc = forward_batch.out_cache_loc.view(
+                    self.speculative_num_steps, -1
+                ).T
+
+                for idx, single_seq_len in enumerate(seq_lens_with_decode):
+                    real_bsz_start_idx = idx
+                    real_bsz_end_idx = idx + 1
+                    metadata.page_table[
+                        real_bsz_start_idx:real_bsz_end_idx,
+                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
+                    ] = cache_loc[
+                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
+                    ]
+            else:  # Normal Decode without Spec Decoding
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+        elif forward_batch.forward_mode.is_target_verify():
+            # Note: Target Verify will be run on the Target Worker
+            draft_token_num = forward_batch.spec_info.draft_token_num
+            metadata.cache_seqlens_int32 = (
+                forward_batch.seq_lens + draft_token_num
+            ).to(torch.int32)
+            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item() + draft_token_num
+            )
             metadata.cu_seqlens_q = torch.arange(
-                0, batch_size + 1, dtype=torch.int32, device=device
+                0,
+                batch_size * draft_token_num + 1,
+                draft_token_num,
+                dtype=torch.int32,
+                device=device,
             )
-        else:
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
+
+        elif forward_batch.forward_mode.is_extend_or_draft_extend():
+            # Normal or Draft Extend (both of them will be run on the Target Worker)
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            )
+            # Precompute maximum sequence length
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            # Precompute page table
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
             # Precompute cumulative sequence lengths
-            if not extend_no_prefix:
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
+                extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
-                metadata.max_seq_len_q = seqlens_in_batch.max().item()
+                metadata.max_seq_len_q = metadata.max_seq_len_k
+
+        # Precompute strided indices
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
         self.forward_metadata = metadata
 
     def forward_extend(
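Note (illustrative, not from the package): the page_size > 1 branch at the end of init_forward_metadata turns the token-level req_to_token table into page indices by sampling every page_size-th slot and dividing by page_size. A small sketch with invented slot values and an assumed page_size of 4:

    import torch

    page_size = 4
    # Invented token-slot row for one request: 16 tokens stored in slots 32..47.
    req_to_token_row = torch.arange(32, 48, dtype=torch.int32)

    strided_indices = torch.arange(0, req_to_token_row.shape[0], page_size)
    page_table_row = req_to_token_row[strided_indices] // page_size
    print(page_table_row.tolist())  # [8, 9, 10, 11] -> the pages backing this request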
@@ -105,23 +222,29 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )
-
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
                 )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
 
         # Use precomputed metadata
         metadata = self.forward_metadata
 
-        # # Use Flash Attention for prefill
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -130,26 +253,72 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]
-        o = flash_attn_with_kvcache(
-            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=metadata.max_seq_len_q,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        page_table = metadata.page_table
+
+        # Use Flash Attention for prefill
+        if not self.use_mla:
+            # Do multi-head attention
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+            o = flash_attn_with_kvcache(
+                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def forward_decode(
         self,
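Note (illustrative, not from the package): on the absorbed multi-latent attention path above, the cache holds a single latent per token (a compressed part of width v_head_dim plus a rope part of width head_dim - v_head_dim), and the query is split the same way: q_rope is matched against the rope keys while q_nope is passed as qv against the latent part. A rough sketch of the split with assumed DeepSeek-style sizes (kv_lora_rank 512, rope dim 64; these numbers are not taken from this diff):

    import torch

    num_heads, v_head_dim, rope_dim = 16, 512, 64
    head_dim = v_head_dim + rope_dim  # 576

    q = torch.randn(8, num_heads * head_dim)   # 8 tokens, heads flattened
    q_all = q.view(-1, num_heads, head_dim)
    q_nope = q_all[:, :, :v_head_dim]          # matched against the compressed latent
    q_rope = q_all[:, :, v_head_dim:]          # matched against the rope part of the key
    print(q_nope.shape, q_rope.shape)          # torch.Size([8, 16, 512]) torch.Size([8, 16, 64])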
@@ -162,26 +331,29 @@ class FlashAttentionBackend(AttentionBackend):
     ) -> torch.Tensor:
         """Forward pass with FlashAttention using precomputed metadata."""
         # Save KV cache if needed
-        if k is not None and v is not None and save_kv_cache:
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-            )
-
-        # Get KV cache
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
 
         # Use precomputed metadata
         metadata = self.forward_metadata
 
-        # Pre-reshape query tensor
-        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -190,25 +362,75 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        # Run attention with precomputed values
-        o = flash_attn_with_kvcache(
-            q=q_reshaped,
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=1,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )
+        page_table = metadata.page_table
+
+        if not self.use_mla:
+            # Do multi-head attention
+
+            # Get KV cache
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+            # Pre-reshape query tensor
+            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            o = flash_attn_with_kvcache(
+                q=q_reshaped,
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def init_cuda_graph_state(self, max_bs: int):
         """Initialize CUDA graph state for the attention backend.
@@ -219,11 +441,49 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-        # Initialize fixed size tensors for decode operations
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
-                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "page_table_draft_decode": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
+            ),
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+        }
+
+        self.target_verify_metadata = {
+            "page_table": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "max_seqlen_q": 0,
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
             ),
         }
 
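Note (illustrative, not from the package): the dictionaries above follow the usual CUDA-graph constraint that every tensor read by the captured kernels must keep a fixed address, so worst-case buffers are allocated once and later updated in place with copy_() on slices rather than reallocated. A generic sketch of that pattern with invented names and sizes:

    import torch

    max_bs = 160
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Allocate once at worst-case size (capture time).
    cache_seqlens_buf = torch.zeros(max_bs, dtype=torch.int32, device=device)

    # Per replay: write new values into a fixed-address slice, no new allocation.
    bs = 4
    new_seq_lens = torch.tensor([7, 12, 3, 9], dtype=torch.int32, device=device)
    cache_seqlens_buf[:bs].copy_(new_seq_lens)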
@@ -239,27 +499,89 @@ class FlashAttentionBackend(AttentionBackend):
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
-        # Get sequence information
-        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-        batch_size = len(seq_lens)
         device = seq_lens.device
-        metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
-        )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seq_lens.max().item()
-        # Precompute page table
-        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
-            req_pool_indices, :
-        ]
-        if forward_mode == ForwardMode.DECODE:
-            # Precompute cumulative sequence lengths
-            metadata.cu_seqlens_q = torch.arange(
-                0, batch_size + 1, dtype=torch.int32, device=device
+        if forward_mode.is_decode():
+            if spec_info is not None:
+                # Draft Decode
+                metadata.cu_seqlens_q = torch.arange(
+                    0, bs + 1, dtype=torch.int32, device=device
+                )
+                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                    "cache_seqlens"
+                ][:bs]
+
+                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+                    : bs + 1
+                ]
+
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
+                metadata.page_table = self.decode_cuda_graph_metadata[
+                    "page_table_draft_decode"
+                ][req_pool_indices, :]
+            else:
+                # Normal Decode
+                # Get sequence information
+                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+                batch_size = len(seq_lens)
+                device = seq_lens.device
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = seq_lens.max().item()
+                # Precompute page table
+                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+                    req_pool_indices, :
+                ]
+                # Precompute cumulative sequence lengths
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+            self.decode_cuda_graph_metadata[bs] = metadata
+        elif forward_mode.is_target_verify():
+            draft_token_num = spec_info.draft_token_num
+
+            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + draft_token_num).to(torch.int32)
             )
-        else:
-            raise ValueError("Do not support Prefill Mode cuda graph")
-        self.decode_cuda_graph_metadata[bs] = metadata
+
+            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
+
+            metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
+                torch.arange(
+                    0,
+                    bs * draft_token_num + 1,
+                    draft_token_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+            ]
+            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
+            cu_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+            )
+            metadata.cu_seqlens_k = cu_k
+            metadata.page_table = self.target_verify_metadata["page_table"][
+                req_pool_indices, :
+            ]
+
+            self.target_verify_metadata[bs] = metadata
+
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
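Note (illustrative, not from the package): in the target-verify branches above, every request contributes exactly draft_token_num query tokens, so cu_seqlens_q reduces to an arange with that step. A worked example with invented sizes (bs = 3, draft_token_num = 4):

    import torch

    bs, draft_token_num = 3, 4
    cu_seqlens_q = torch.arange(
        0, bs * draft_token_num + 1, draft_token_num, dtype=torch.int32
    )
    print(cu_seqlens_q.tolist())  # [0, 4, 8, 12]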
@@ -272,24 +594,159 @@ class FlashAttentionBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
+        out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        seqlens_in_batch = seq_lens[:bs]
-        metadata = self.decode_cuda_graph_metadata[bs]
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
-        metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
-        )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seqlens_in_batch.max().item()
-        # Only zero out the part out of max_len_k
-        metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
-        # Then do the copy
-        metadata.page_table[:, : metadata.max_seq_len_k].copy_(
-            self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
-        )
-        self.forward_decode_metadata = metadata
+        device = seq_lens.device
+        seq_lens = seq_lens[:bs]
+        req_pool_indices = req_pool_indices[:bs]
+        seq_lens_cpu = seq_lens_cpu[:bs]
+        if forward_mode.is_decode():
+            metadata = self.decode_cuda_graph_metadata[bs]
+
+            if spec_info is not None:
+                # Draft Decode
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + (self.step_id + 1)
+
+                metadata.cache_seqlens_int32.copy_(
+                    (seq_lens + (self.step_id + 1)).to(torch.int32)
+                )
+
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
+
+                metadata.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+
+                page_table = self.req_to_token[
+                    req_pool_indices, : metadata.max_seq_len_k
+                ]
+
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+            else:
+                # Normal Decode
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len
+
+                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+                page_indices = self.req_to_token[
+                    :,
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                ]
+                page_indices = page_indices[req_pool_indices] // self.page_size
+                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                metadata.page_table[:, max_seq_pages:].fill_(0)
+
+        elif forward_mode.is_target_verify():
+            metadata = self.target_verify_metadata[bs]
+            draft_token_num = spec_info.draft_token_num
+
+            metadata.cu_seqlens_q.copy_(
+                torch.arange(
+                    0,
+                    bs * draft_token_num + 1,
+                    draft_token_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+            )
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + draft_token_num).to(torch.int32)
+            )
+
+            metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+            metadata.cu_seqlens_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+            )
+            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
+            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+        self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
+
+
+class FlashAttentionMultiStepBackend:
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        self.model_runner = model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashAttentionBackend(
+                    model_runner,
+                    topk=self.topk,
+                    speculative_num_steps=self.speculative_num_steps,
+                    step_id=i,
+                )
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+                out_cache_loc=forward_batch.out_cache_loc,
+            )
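Note (illustrative, not from the package): the multi-step wrapper above keys each per-step backend off step_id, and during draft decode the key-side lengths are extended by (step_id + 1) because step i of the draft loop has already appended i + 1 speculative tokens per request. A toy illustration of that bookkeeping with invented lengths:

    import torch

    seq_lens = torch.tensor([10, 20], dtype=torch.int32)  # committed tokens per request
    speculative_num_steps = 3

    # What each per-step backend would use as cache_seqlens during draft decode.
    for step_id in range(speculative_num_steps):
        print(step_id, (seq_lens + (step_id + 1)).tolist())
    # 0 [11, 21]
    # 1 [12, 22]
    # 2 [13, 23]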