sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. The information is provided for informational purposes only.
Files changed (63)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/srt/layers/attention/__init__.py +14 -5
  3. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  4. sglang/srt/layers/attention/flashinfer_backend.py +211 -81
  5. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  6. sglang/srt/layers/attention/triton_backend.py +20 -11
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  8. sglang/srt/layers/logits_processor.py +167 -212
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
  31. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
  32. sglang/srt/layers/quantization/fp8.py +2 -2
  33. sglang/srt/layers/sampler.py +57 -21
  34. sglang/srt/layers/torchao_utils.py +17 -3
  35. sglang/srt/managers/io_struct.py +1 -2
  36. sglang/srt/managers/schedule_batch.py +26 -2
  37. sglang/srt/managers/schedule_policy.py +159 -90
  38. sglang/srt/managers/scheduler.py +62 -26
  39. sglang/srt/managers/tokenizer_manager.py +22 -20
  40. sglang/srt/managers/tp_worker.py +16 -4
  41. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  42. sglang/srt/model_executor/cuda_graph_runner.py +118 -73
  43. sglang/srt/model_executor/forward_batch_info.py +33 -8
  44. sglang/srt/model_executor/model_runner.py +63 -61
  45. sglang/srt/models/deepseek_v2.py +34 -7
  46. sglang/srt/models/grok.py +97 -26
  47. sglang/srt/openai_api/adapter.py +0 -17
  48. sglang/srt/openai_api/protocol.py +3 -3
  49. sglang/srt/sampling/sampling_batch_info.py +21 -0
  50. sglang/srt/sampling/sampling_params.py +9 -1
  51. sglang/srt/server.py +9 -5
  52. sglang/srt/server_args.py +108 -57
  53. sglang/srt/speculative/build_eagle_tree.py +347 -0
  54. sglang/srt/speculative/eagle_utils.py +618 -0
  55. sglang/srt/speculative/eagle_worker.py +170 -0
  56. sglang/srt/speculative/spec_info.py +5 -0
  57. sglang/srt/utils.py +15 -2
  58. sglang/version.py +1 -1
  59. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
  60. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/RECORD +63 -39
  61. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
  62. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
  63. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -63,6 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server import _set_envs_and_config
  from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers


@@ -214,6 +215,7 @@ def extend(reqs, model_runner):
  tree_cache=None,
  model_config=model_runner.model_config,
  enable_overlap=False,
+ spec_algorithm=SpeculativeAlgorithm.NONE,
  )
  batch.prepare_for_extend()
  model_worker_batch = batch.get_model_worker_batch()
sglang/srt/layers/attention/__init__.py CHANGED
@@ -1,10 +1,14 @@
+ from __future__ import annotations
+
  from abc import ABC, abstractmethod
- from typing import Optional
+ from typing import TYPE_CHECKING, Optional

  import torch

- from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ if TYPE_CHECKING:
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.speculative.spec_info import SpecInfo


  class AttentionBackend(ABC):
@@ -22,9 +26,12 @@ class AttentionBackend(ABC):
  def init_forward_metadata_capture_cuda_graph(
  self,
  bs: int,
+ num_tokens: int,
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
- encoder_lens: Optional[torch.Tensor] = None,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[SpecInfo],
  ):
  """Init the metadata for a forward pass for capturing a cuda graph."""
  raise NotImplementedError()
@@ -35,7 +42,9 @@ class AttentionBackend(ABC):
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- encoder_lens: Optional[torch.Tensor] = None,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[SpecInfo],
  ):
  """Init the metadata for a forward pass for replying a cuda graph."""
  raise NotImplementedError()
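Note: the two hunks above change the abstract CUDA-graph hooks on AttentionBackend. As a rough illustration only (not part of the diff; the subclass name and elided bodies are hypothetical), a backend implementing the updated interface now has to accept the extra arguments:

    from __future__ import annotations

    from typing import TYPE_CHECKING, Optional

    import torch

    from sglang.srt.layers.attention import AttentionBackend

    if TYPE_CHECKING:
        from sglang.srt.model_executor.forward_batch_info import ForwardMode
        from sglang.srt.speculative.spec_info import SpecInfo


    class HypotheticalBackend(AttentionBackend):
        def init_forward_metadata_capture_cuda_graph(
            self,
            bs: int,
            num_tokens: int,                       # new in 0.4.1.post4
            req_pool_indices: torch.Tensor,
            seq_lens: torch.Tensor,
            encoder_lens: Optional[torch.Tensor],  # no longer defaulted to None
            forward_mode: ForwardMode,             # new
            spec_info: Optional[SpecInfo],         # new
        ):
            ...  # capture-time metadata setup

        def init_forward_metadata_replay_cuda_graph(
            self,
            bs: int,
            req_pool_indices: torch.Tensor,
            seq_lens: torch.Tensor,
            seq_lens_sum: int,
            encoder_lens: Optional[torch.Tensor],
            forward_mode: ForwardMode,
            spec_info: Optional[SpecInfo],
        ):
            ...  # replay-time metadata update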
sglang/srt/layers/attention/double_sparsity_backend.py CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
  from typing import TYPE_CHECKING

  import torch
- import torch.nn as nn

  from sglang.srt.layers.attention import AttentionBackend
  from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend):

  self.forward_metadata = None

- self.cuda_graph_max_seq_len = model_runner.model_config.context_len
-
  def init_forward_metadata(self, forward_batch: ForwardBatch):
  """Init auxiliary variables for triton attention backend."""

@@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
  ds_req_to_token,
  )

- def init_cuda_graph_state(self, max_bs: int):
- # TODO(Andy): Support CUDA graph for double sparse attention
- raise ValueError(
- "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
- )
- self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-
- self.cuda_graph_start_loc = torch.zeros(
- (max_bs,), dtype=torch.int32, device="cuda"
- )
- self.cuda_graph_attn_logits = torch.empty(
- (
- self.num_head,
- self.cuda_graph_max_total_num_tokens,
- ),
- dtype=self.reduce_dtype,
- device="cuda",
- )
-
- def init_forward_metadata_capture_cuda_graph(
- self,
- bs: int,
- req_pool_indices: torch.Tensor,
- seq_lens: torch.Tensor,
- encoder_lens=None,
- ):
- # NOTE: encoder_lens expected to be zeros or None
- self.forward_metadata = (
- self.cuda_graph_start_loc,
- self.cuda_graph_attn_logits,
- self.cuda_graph_max_seq_len,
- None,
- )
-
- def init_forward_metadata_replay_cuda_graph(
- self,
- bs: int,
- req_pool_indices: torch.Tensor,
- seq_lens: torch.Tensor,
- seq_lens_sum: int,
- encoder_lens=None,
- ):
- # NOTE: encoder_lens expected to be zeros or None
- self.cuda_graph_start_loc.zero_()
- self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
- def get_cuda_graph_seq_len_fill_value(self):
- return 1
-
  def forward_extend(
  self,
  q,
sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -10,7 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
  import os
  from dataclasses import dataclass
  from enum import Enum, auto
- from typing import TYPE_CHECKING, List, Union
+ from typing import TYPE_CHECKING, List, Optional, Union

  import torch
  import triton
@@ -18,12 +18,13 @@ import triton.language as tl

  from sglang.global_config import global_config
  from sglang.srt.layers.attention import AttentionBackend
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.utils import is_flashinfer_available

  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner
+ from sglang.srt.speculative.spec_info import SpecInfo

  if is_flashinfer_available():
  from flashinfer import (
@@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend):
  # Two wrappers: one for sliding window attention and one for full attention.
  # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
  self.prefill_wrappers_paged = []
+ self.prefill_wrappers_verify = []
  self.decode_wrappers = []
  for _ in range(self.num_wrappers):
  self.prefill_wrappers_paged.append(
  BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
  )
+ self.prefill_wrappers_verify.append(
+ BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+ )
  self.decode_wrappers.append(
  BatchDecodeWithPagedKVCacheWrapper(
  self.workspace_buffer,
@@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend):
  # Other metadata
  self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
  self.decode_cuda_graph_metadata = {}
+ self.prefill_cuda_graph_metadata = {}

  def init_forward_metadata(self, forward_batch: ForwardBatch):
  if forward_batch.forward_mode.is_decode():
@@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch.seq_lens_sum,
  decode_wrappers=self.decode_wrappers,
  encoder_lens=forward_batch.encoder_lens,
+ spec_info=forward_batch.spec_info,
  )
  self.forward_metadata = DecodeMetadata(self.decode_wrappers)
+ elif forward_batch.forward_mode.is_draft_extend():
+ self.indices_updater_prefill.update(
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrappers=self.prefill_wrappers_paged,
+ use_ragged=False,
+ encoder_lens=forward_batch.encoder_lens,
+ spec_info=forward_batch.spec_info,
+ )
+ self.forward_metadata = PrefillMetadata(
+ self.prefill_wrappers_paged, False, False
+ )
+ elif forward_batch.forward_mode.is_target_verify():
+ self.indices_updater_prefill.update(
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrappers=self.prefill_wrappers_verify,
+ use_ragged=False,
+ encoder_lens=forward_batch.encoder_lens,
+ spec_info=forward_batch.spec_info,
+ )
+ self.forward_metadata = PrefillMetadata(
+ self.prefill_wrappers_verify, False, False
+ )
  else:
  prefix_lens = forward_batch.extend_prefix_lens

@@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
  prefill_wrappers=self.prefill_wrappers_paged,
  use_ragged=use_ragged,
  encoder_lens=forward_batch.encoder_lens,
+ spec_info=None,
  )
  self.forward_metadata = PrefillMetadata(
  self.prefill_wrappers_paged, use_ragged, extend_no_prefix
@@ -180,37 +216,82 @@ class FlashInferAttnBackend(AttentionBackend):
  cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
  ]

+ self.cuda_graph_custom_mask = torch.zeros(
+ (max_bs * self.max_context_len),
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+ self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
+
  def init_forward_metadata_capture_cuda_graph(
  self,
  bs: int,
+ num_tokens: int,
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
- encoder_lens: torch.Tensor = None,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[SpecInfo],
  ):
- decode_wrappers = []
- for i in range(self.num_wrappers):
- decode_wrappers.append(
- BatchDecodeWithPagedKVCacheWrapper(
- self.workspace_buffer,
- "NHD",
- use_cuda_graph=True,
- use_tensor_cores=self.decode_use_tensor_cores,
- paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
- paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
- paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
+ if forward_mode.is_decode():
+ decode_wrappers = []
+ for i in range(self.num_wrappers):
+ decode_wrappers.append(
+ BatchDecodeWithPagedKVCacheWrapper(
+ self.workspace_buffer,
+ "NHD",
+ use_cuda_graph=True,
+ use_tensor_cores=self.decode_use_tensor_cores,
+ paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+ paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+ paged_kv_last_page_len_buffer=self.kv_last_page_len[
+ :num_tokens
+ ],
+ )
  )
+ seq_lens_sum = seq_lens.sum().item()
+ self.indices_updater_decode.update(
+ req_pool_indices,
+ seq_lens,
+ seq_lens_sum,
+ decode_wrappers=decode_wrappers,
+ encoder_lens=encoder_lens,
+ spec_info=spec_info,
  )
-
- seq_lens_sum = seq_lens.sum().item()
- self.indices_updater_decode.update(
- req_pool_indices,
- seq_lens,
- seq_lens_sum,
- decode_wrappers=decode_wrappers,
- encoder_lens=encoder_lens,
- )
- self.decode_cuda_graph_metadata[bs] = decode_wrappers
- self.forward_metadata = DecodeMetadata(decode_wrappers)
+ self.decode_cuda_graph_metadata[bs] = decode_wrappers
+ self.forward_metadata = DecodeMetadata(decode_wrappers)
+ elif forward_mode.is_target_verify():
+ prefill_wrappers = []
+ for i in range(self.num_wrappers):
+ prefill_wrappers.append(
+ BatchPrefillWithPagedKVCacheWrapper(
+ self.workspace_buffer,
+ "NHD",
+ use_cuda_graph=True,
+ qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
+ paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
+ paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
+ paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
+ custom_mask_buf=self.cuda_graph_custom_mask,
+ qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+ )
+ )
+ seq_lens_sum = seq_lens.sum().item()
+ self.indices_updater_prefill.update(
+ req_pool_indices,
+ seq_lens,
+ seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrappers=prefill_wrappers,
+ use_ragged=False,
+ encoder_lens=encoder_lens,
+ spec_info=spec_info,
+ )
+ self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
+ self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
+ else:
+ raise ValueError(f"Invalid mode: {forward_mode=}")

  def init_forward_metadata_replay_cuda_graph(
  self,
@@ -218,24 +299,41 @@ class FlashInferAttnBackend(AttentionBackend):
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- encoder_lens: torch.Tensor = None,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[SpecInfo],
  ):
- self.indices_updater_decode.update(
- req_pool_indices[:bs],
- seq_lens[:bs],
- seq_lens_sum,
- decode_wrappers=self.decode_cuda_graph_metadata[bs],
- encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
- )
+ if forward_mode.is_decode():
+ self.indices_updater_decode.update(
+ req_pool_indices[:bs],
+ seq_lens[:bs],
+ seq_lens_sum,
+ decode_wrappers=self.decode_cuda_graph_metadata[bs],
+ encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+ spec_info=spec_info,
+ )
+ elif forward_mode.is_target_verify():
+ self.indices_updater_prefill.update(
+ req_pool_indices[:bs],
+ seq_lens[:bs],
+ seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
+ use_ragged=False,
+ encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+ spec_info=spec_info,
+ )
+ else:
+ raise ValueError("Invalid forward mode")

  def get_cuda_graph_seq_len_fill_value(self):
  return 0

  def forward_extend(
  self,
- q,
- k,
- v,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache=True,
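Note: as a rough sketch of how the reworked CUDA-graph hooks above might be driven (none of this code is in the diff, and the helper name is hypothetical), a caller now threads the forward mode and any speculative-decoding info through both calls; the decode path sizes its graph buffers by num_tokens rather than bs:

    import torch

    from sglang.srt.model_executor.forward_batch_info import ForwardMode


    def capture_then_replay(
        backend,                 # any AttentionBackend implementation
        bs: int,
        num_tokens: int,         # tokens in the captured batch; typically equals bs for plain decode
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        spec_info=None,          # a speculative-decoding SpecInfo object, or None
    ) -> None:
        # Capture: the FlashInfer backend above accepts only decode and
        # target-verify modes here; anything else raises ValueError.
        backend.init_forward_metadata_capture_cuda_graph(
            bs,
            num_tokens,
            req_pool_indices,
            seq_lens,
            encoder_lens=None,
            forward_mode=ForwardMode.DECODE,
            spec_info=spec_info,
        )
        # Replay: same dispatch, with the two new trailing keyword arguments.
        backend.init_forward_metadata_replay_cuda_graph(
            bs,
            req_pool_indices,
            seq_lens,
            int(seq_lens.sum().item()),
            encoder_lens=None,
            forward_mode=ForwardMode.DECODE,
            spec_info=spec_info,
        )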
@@ -293,9 +391,9 @@ class FlashInferAttnBackend(AttentionBackend):

  def forward_decode(
  self,
- q,
- k,
- v,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache=True,
@@ -348,7 +446,6 @@ class FlashInferIndicesUpdaterDecode:
  self.data_type = model_runner.kv_cache_dtype
  self.q_data_type = model_runner.dtype
  self.sliding_window_size = model_runner.sliding_window_size
-
  self.attn_backend = attn_backend

  # Buffers and wrappers
@@ -371,7 +468,8 @@ class FlashInferIndicesUpdaterDecode:
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
- encoder_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
  ):
  # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()
@@ -382,7 +480,8 @@ class FlashInferIndicesUpdaterDecode:
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
- encoder_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
  ):
  decode_wrappers = decode_wrappers or self.decode_wrappers
  self.call_begin_forward(
@@ -392,6 +491,7 @@ class FlashInferIndicesUpdaterDecode:
  seq_lens_sum,
  self.kv_indptr[0],
  None,
+ spec_info,
  )

  def update_sliding_window(
@@ -400,7 +500,8 @@ class FlashInferIndicesUpdaterDecode:
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
- encoder_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -424,6 +525,7 @@ class FlashInferIndicesUpdaterDecode:
  paged_kernel_lens_sum_tmp,
  self.kv_indptr[wrapper_id],
  kv_start_idx_tmp,
+ spec_info,
  )

  def update_cross_attention(
@@ -432,7 +534,8 @@ class FlashInferIndicesUpdaterDecode:
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
- encoder_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -452,6 +555,7 @@ class FlashInferIndicesUpdaterDecode:
  seq_lens_sum,
  self.kv_indptr[wrapper_id],
  kv_start_idx,
+ spec_info,
  )

  def call_begin_forward(
@@ -462,23 +566,30 @@ class FlashInferIndicesUpdaterDecode:
  paged_kernel_lens_sum: int,
  kv_indptr: torch.Tensor,
  kv_start_idx: torch.Tensor,
+ spec_info: Optional[SpecInfo],
  ):
- bs = len(req_pool_indices)
- kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
- kv_indptr = kv_indptr[: bs + 1]
- kv_indices = torch.empty(
- paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
- )
-
- create_flashinfer_kv_indices_triton[(bs,)](
- self.req_to_token,
- req_pool_indices,
- paged_kernel_lens,
- kv_indptr,
- kv_start_idx,
- kv_indices,
- self.req_to_token.shape[1],
- )
+ if spec_info is None:
+ bs = len(req_pool_indices)
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+ kv_indptr = kv_indptr[: bs + 1]
+ kv_indices = torch.empty(
+ paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+ )
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ paged_kernel_lens,
+ kv_indptr,
+ kv_start_idx,
+ kv_indices,
+ self.req_to_token.shape[1],
+ )
+ else:
+ bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
+ req_pool_indices,
+ paged_kernel_lens,
+ self.req_to_token,
+ )

  wrapper.end_forward()
  wrapper.begin_forward(
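Note: the decode-side call_begin_forward above now delegates KV-index construction to spec_info when one is present. A duck-typed sketch of what that object is expected to provide, inferred only from the call site (the real class ships with the new sglang/srt/speculative modules, which are not shown in this excerpt):

    import torch


    class SpecInfoLike:  # hypothetical stand-in, not the real class
        def generate_attn_arg_decode(
            self,
            req_pool_indices: torch.Tensor,
            paged_kernel_lens: torch.Tensor,
            req_to_token: torch.Tensor,
        ):
            # Expected to return (bs, kv_indices, kv_indptr) ready to be fed
            # to wrapper.begin_forward() in place of the Triton-built indices.
            raise NotImplementedError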
@@ -507,7 +618,6 @@ class FlashInferIndicesUpdaterPrefill:
  self.data_type = model_runner.kv_cache_dtype
  self.q_data_type = model_runner.dtype
  self.sliding_window_size = model_runner.sliding_window_size
-
  self.attn_backend = attn_backend

  # Buffers and wrappers
@@ -534,7 +644,8 @@ class FlashInferIndicesUpdaterPrefill:
  prefix_lens: torch.Tensor,
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
- encoder_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
  ):
  # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()
@@ -547,7 +658,8 @@ class FlashInferIndicesUpdaterPrefill:
  prefix_lens: torch.Tensor,
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
- encoder_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
  ):
  if use_ragged:
  paged_kernel_lens = prefix_lens
@@ -568,6 +680,7 @@ class FlashInferIndicesUpdaterPrefill:
  self.kv_indptr[0],
  self.qo_indptr[0],
  use_ragged,
+ spec_info,
  )

  def update_sliding_window(
@@ -578,7 +691,8 @@ class FlashInferIndicesUpdaterPrefill:
  prefix_lens: torch.Tensor,
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
- encoder_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -607,6 +721,7 @@ class FlashInferIndicesUpdaterPrefill:
  self.kv_indptr[wrapper_id],
  self.qo_indptr[wrapper_id],
  use_ragged,
+ spec_info,
  )

  def update_cross_attention(
@@ -617,7 +732,8 @@ class FlashInferIndicesUpdaterPrefill:
  prefix_lens: torch.Tensor,
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
- encoder_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -643,6 +759,7 @@ class FlashInferIndicesUpdaterPrefill:
  self.kv_indptr[wrapper_id],
  self.qo_indptr[wrapper_id],
  use_ragged,
+ spec_info,
  )

  def call_begin_forward(
@@ -658,25 +775,37 @@ class FlashInferIndicesUpdaterPrefill:
  kv_indptr: torch.Tensor,
  qo_indptr: torch.Tensor,
  use_ragged: bool,
+ spec_info: Optional[SpecInfo],
  ):
  bs = len(req_pool_indices)
- kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
- kv_indptr = kv_indptr[: bs + 1]
- kv_indices = torch.empty(
- paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
- )
- create_flashinfer_kv_indices_triton[(bs,)](
- self.req_to_token,
- req_pool_indices,
- paged_kernel_lens,
- kv_indptr,
- kv_start_idx,
- kv_indices,
- self.req_to_token.shape[1],
- )
+ if spec_info is None:
+ # Normal extend
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+ kv_indptr = kv_indptr[: bs + 1]
+ kv_indices = torch.empty(
+ paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+ )
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ paged_kernel_lens,
+ kv_indptr,
+ kv_start_idx,
+ kv_indices,
+ self.req_to_token.shape[1],
+ )

- qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
- qo_indptr = qo_indptr[: bs + 1]
+ qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+ qo_indptr = qo_indptr[: bs + 1]
+ custom_mask = None
+ else:
+ kv_indices, kv_indptr, qo_indptr, custom_mask = (
+ spec_info.generate_attn_arg_prefill(
+ req_pool_indices,
+ paged_kernel_lens,
+ self.req_to_token,
+ )
+ )

  # extend part
  if use_ragged:
@@ -702,6 +831,7 @@ class FlashInferIndicesUpdaterPrefill:
  self.head_dim,
  1,
  q_data_type=self.q_data_type,
+ custom_mask=custom_mask,
  )
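Note: the prefill-side counterpart above works the same way, except that when spec_info is set it also supplies qo_indptr and a custom attention mask, which is then forwarded to begin_forward via the new custom_mask argument. A duck-typed sketch inferred from that call site (names and return shape only; the real implementation lives in the new speculative modules):

    import torch


    class SpecInfoLike:  # hypothetical stand-in, not the real class
        def generate_attn_arg_prefill(
            self,
            req_pool_indices: torch.Tensor,
            paged_kernel_lens: torch.Tensor,
            req_to_token: torch.Tensor,
        ):
            # Expected to return (kv_indices, kv_indptr, qo_indptr, custom_mask);
            # custom_mask is None for a normal extend and a flattened mask
            # (presumably matching the uint8 cuda_graph_custom_mask buffer)
            # when verifying speculative draft tokens.
            raise NotImplementedError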