sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/double_sparsity_backend.py

@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 import torch
-import torch.nn as nn

 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend):

         self.forward_metadata = None

-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

@@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
             ds_req_to_token,
         )

-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO(Andy): Support CUDA graph for double sparse attention
-        raise ValueError(
-            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
-            device="cuda",
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.forward_metadata = (
-            self.cuda_graph_start_loc,
-            self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
-            None,
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
     def forward_extend(
         self,
         q,
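
The CUDA-graph helpers removed above were effectively dead code: `init_cuda_graph_state` raised a `ValueError` before any of the buffer setup below it could run, so this backend never actually supported CUDA graphs. For context on the `cuda_graph_start_loc` buffer those stubs referenced, it is an exclusive prefix sum of the per-request sequence lengths, i.e. the offset at which each request's tokens start in a flattened batch buffer. A minimal, self-contained sketch of that indexing (made-up values, not sglang's API):

```python
import torch

# Per-request sequence lengths for a batch of three requests (example values).
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
bs = seq_lens.numel()

# Exclusive prefix sum: request i's tokens begin at start_loc[i] in the flat buffer.
start_loc = torch.zeros(bs, dtype=torch.int32)
start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

print(start_loc.tolist())  # [0, 5, 8]
```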
sglang/srt/layers/attention/flashinfer_backend.py

@@ -10,7 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch
 import triton
@@ -18,12 +18,13 @@ import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo

 if is_flashinfer_available():
     from flashinfer import (
@@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend):
         # Two wrappers: one for sliding window attention and one for full attention.
         # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         self.prefill_wrappers_paged = []
+        self.prefill_wrappers_verify = []
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             self.prefill_wrappers_paged.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
+            self.prefill_wrappers_verify.append(
+                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+            )
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,
@@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         self.decode_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
@@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
+        elif forward_batch.forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_paged,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, False, False
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_verify,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_verify, False, False
+            )
         else:
             prefix_lens = forward_batch.extend_prefix_lens

@@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=None,
             )
             self.forward_metadata = PrefillMetadata(
                 self.prefill_wrappers_paged, use_ragged, extend_no_prefix
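
Taken together, the new branches in `init_forward_metadata` dispatch on the batch's forward mode: plain decode and extend keep their existing paths, the speculative draft-extend mode reuses the paged prefill wrappers, and target-verify uses the newly added `prefill_wrappers_verify` so its buffers stay separate from normal prefill. A stripped-down sketch of just that dispatch pattern, with hypothetical names standing in for the real wrapper and metadata objects:

```python
from enum import Enum, auto


class Mode(Enum):
    DECODE = auto()
    EXTEND = auto()
    DRAFT_EXTEND = auto()   # draft model appends speculative tokens
    TARGET_VERIFY = auto()  # target model scores the speculated tokens


def pick_wrappers(mode, decode_ws, prefill_ws, verify_ws):
    """Choose which wrapper set the indices updater should refresh for this mode."""
    if mode is Mode.DECODE:
        return decode_ws
    if mode in (Mode.EXTEND, Mode.DRAFT_EXTEND):
        return prefill_ws
    if mode is Mode.TARGET_VERIFY:
        return verify_ws  # analogous to prefill_wrappers_verify above
    raise ValueError(f"Invalid mode: {mode}")


assert pick_wrappers(Mode.TARGET_VERIFY, ["d"], ["p"], ["v"]) == ["v"]
```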
@@ -180,37 +216,82 @@ class FlashInferAttnBackend(AttentionBackend):
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]

+        self.cuda_graph_custom_mask = torch.zeros(
+            (max_bs * self.max_context_len),
+            dtype=torch.uint8,
+            device="cuda",
+        )
+        self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+        self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: torch.Tensor = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        decode_wrappers = []
-        for i in range(self.num_wrappers):
-            decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_cuda_graph=True,
-                    use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
-                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
+        if forward_mode.is_decode():
+            decode_wrappers = []
+            for i in range(self.num_wrappers):
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                            :num_tokens
+                        ],
+                    )
                 )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_decode.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                decode_wrappers=decode_wrappers,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
             )
-
-        seq_lens_sum = seq_lens.sum().item()
-        self.indices_updater_decode.update(
-            req_pool_indices,
-            seq_lens,
-            seq_lens_sum,
-            decode_wrappers=decode_wrappers,
-            encoder_lens=encoder_lens,
-        )
-        self.decode_cuda_graph_metadata[bs] = decode_wrappers
-        self.forward_metadata = DecodeMetadata(decode_wrappers)
+            self.decode_cuda_graph_metadata[bs] = decode_wrappers
+            self.forward_metadata = DecodeMetadata(decode_wrappers)
+        elif forward_mode.is_target_verify():
+            prefill_wrappers = []
+            for i in range(self.num_wrappers):
+                prefill_wrappers.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
+                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
+                        custom_mask_buf=self.cuda_graph_custom_mask,
+                        qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+                    )
+                )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=prefill_wrappers,
+                use_ragged=False,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
+            self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -218,24 +299,41 @@ class FlashInferAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: torch.Tensor = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        self.indices_updater_decode.update(
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            seq_lens_sum,
-            decode_wrappers=self.decode_cuda_graph_metadata[bs],
-            encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
-        )
+        if forward_mode.is_decode():
+            self.indices_updater_decode.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                decode_wrappers=self.decode_cuda_graph_metadata[bs],
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        elif forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        else:
+            raise ValueError("Invalid forward mode")

     def get_cuda_graph_seq_len_fill_value(self):
         return 0

     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -249,6 +347,8 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )

+        logits_soft_cap = layer.logit_cap
+
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
@@ -261,7 +361,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
-                logits_soft_cap=layer.logit_cap,
+                logits_soft_cap=logits_soft_cap,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -270,7 +370,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
+                logits_soft_cap=logits_soft_cap,
            )

         if self.forward_metadata.extend_no_prefix:
@@ -293,9 +393,9 @@ class FlashInferAttnBackend(AttentionBackend):

     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -348,7 +448,6 @@ class FlashInferIndicesUpdaterDecode:
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.sliding_window_size = model_runner.sliding_window_size
-
         self.attn_backend = attn_backend

         # Buffers and wrappers
@@ -371,7 +470,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -382,7 +482,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -392,6 +493,7 @@ class FlashInferIndicesUpdaterDecode:
             seq_lens_sum,
             self.kv_indptr[0],
             None,
+            spec_info,
         )

     def update_sliding_window(
@@ -400,7 +502,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -424,6 +527,7 @@ class FlashInferIndicesUpdaterDecode:
                 paged_kernel_lens_sum_tmp,
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
+                spec_info,
             )

     def update_cross_attention(
@@ -432,7 +536,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -452,6 +557,7 @@ class FlashInferIndicesUpdaterDecode:
                 seq_lens_sum,
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
+                spec_info,
             )

     def call_begin_forward(
@@ -462,23 +568,30 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
+        spec_info: Optional[SpecInfo],
     ):
-        bs = len(req_pool_indices)
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(
-            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            kv_start_idx,
-            kv_indices,
-            self.req_to_token.shape[1],
-        )
+        if spec_info is None:
+            bs = len(req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+        else:
+            bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
+                req_pool_indices,
+                paged_kernel_lens,
+                self.req_to_token,
+            )

         wrapper.end_forward()
         wrapper.begin_forward(
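
In the non-speculative branch above, the kernel arguments follow FlashInfer's CSR-style paged-KV layout: `kv_indptr` holds cumulative per-request KV lengths with a leading zero, and `kv_indices` is a flat array of KV-pool slots in which the segment `[kv_indptr[i], kv_indptr[i + 1])` belongs to request `i`. A small worked example of that layout, assuming page size 1 and made-up slot numbers (in sglang the slots are gathered from `req_to_token` by the `create_flashinfer_kv_indices_triton` kernel):

```python
import torch

# Per-request KV lengths for a batch of three requests (example values).
paged_kernel_lens = torch.tensor([4, 2, 3], dtype=torch.int32)
bs = paged_kernel_lens.numel()

# kv_indptr[i + 1] - kv_indptr[i] == number of KV entries owned by request i.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
print(kv_indptr.tolist())  # [0, 4, 6, 9]

# Fabricated slot ids, only to show the per-request segment structure.
kv_indices = torch.arange(int(kv_indptr[-1]), dtype=torch.int32) + 100
for i in range(bs):
    segment = kv_indices[kv_indptr[i] : kv_indptr[i + 1]]
    print(f"request {i}: slots {segment.tolist()}")
```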
@@ -507,7 +620,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.sliding_window_size = model_runner.sliding_window_size
-
         self.attn_backend = attn_backend

         # Buffers and wrappers
@@ -534,7 +646,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -547,7 +660,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -568,6 +682,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.kv_indptr[0],
             self.qo_indptr[0],
             use_ragged,
+            spec_info,
         )

     def update_sliding_window(
@@ -578,7 +693,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -607,6 +723,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.kv_indptr[wrapper_id],
                 self.qo_indptr[wrapper_id],
                 use_ragged,
+                spec_info,
             )

     def update_cross_attention(
@@ -617,7 +734,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -643,6 +761,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.kv_indptr[wrapper_id],
                 self.qo_indptr[wrapper_id],
                 use_ragged,
+                spec_info,
             )

     def call_begin_forward(
@@ -658,25 +777,37 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
+        spec_info: Optional[SpecInfo],
     ):
         bs = len(req_pool_indices)
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(
-            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-        )
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            kv_start_idx,
-            kv_indices,
-            self.req_to_token.shape[1],
-        )
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )

-        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-        qo_indptr = qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    self.req_to_token,
+                )
+            )

         # extend part
         if use_ragged:
@@ -702,6 +833,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.head_dim,
                 1,
                 q_data_type=self.q_data_type,
+                custom_mask=custom_mask,
             )

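Across both updater classes, the optional `spec_info` argument short-circuits the normal index construction: when present, `generate_attn_arg_decode` supplies `(bs, kv_indices, kv_indptr)` for decode, and `generate_attn_arg_prefill` supplies `(kv_indices, kv_indptr, qo_indptr, custom_mask)` for draft-extend and target-verify. The duck-typed interface implied by these call sites can be summarized as a Protocol; this is only an inferred sketch of how the backend calls the object, not the actual class in sglang/srt/speculative/spec_info.py:

```python
from typing import Optional, Protocol, Tuple

import torch


class SpecInfoLike(Protocol):
    """Interface the FlashInfer backend relies on, inferred from the call sites above."""

    def generate_attn_arg_decode(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        req_to_token: torch.Tensor,
    ) -> Tuple[int, torch.Tensor, torch.Tensor]:
        """Return (bs, kv_indices, kv_indptr) for a speculative decode batch."""
        ...

    def generate_attn_arg_prefill(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        req_to_token: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return (kv_indices, kv_indptr, qo_indptr, custom_mask) for verify / draft-extend."""
        ...
```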
sglang/srt/layers/attention/torch_native_backend.py

@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING

 import torch
 from torch.nn.functional import scaled_dot_product_attention
@@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         """Init the metadata for a forward pass."""
         pass

-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
     def _run_sdpa_forward_extend(
         self,
         query: torch.Tensor,