sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +22 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +215 -83
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +170 -215
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +3 -3
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +33 -3
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +68 -28
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +27 -21
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +129 -77
- sglang/srt/model_executor/forward_batch_info.py +51 -21
- sglang/srt/model_executor/model_runner.py +72 -64
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +109 -29
- sglang/srt/models/llama.py +9 -2
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +22 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +20 -13
- sglang/srt/server_args.py +120 -58
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +626 -0
- sglang/srt/speculative/eagle_worker.py +184 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +47 -7
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/double_sparsity_backend.py

@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 import torch
-import torch.nn as nn
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
 
         self.forward_metadata = None
 
-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
@@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
             ds_req_to_token,
         )
 
-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO(Andy): Support CUDA graph for double sparse attention
-        raise ValueError(
-            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
-            device="cuda",
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.forward_metadata = (
-            self.cuda_graph_start_loc,
-            self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
-            None,
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
     def forward_extend(
         self,
         q,
sglang/srt/layers/attention/flashinfer_backend.py

@@ -10,7 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 import triton
@@ -18,12 +18,13 @@ import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
 
 if is_flashinfer_available():
     from flashinfer import (
@@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend):
         # Two wrappers: one for sliding window attention and one for full attention.
         # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         self.prefill_wrappers_paged = []
+        self.prefill_wrappers_verify = []
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             self.prefill_wrappers_paged.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
+            self.prefill_wrappers_verify.append(
+                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+            )
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,
@@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         self.decode_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
@@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
+        elif forward_batch.forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_paged,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, False, False
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_verify,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_verify, False, False
+            )
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
@@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=None,
             )
             self.forward_metadata = PrefillMetadata(
                 self.prefill_wrappers_paged, use_ragged, extend_no_prefix
@@ -180,37 +216,82 @@ class FlashInferAttnBackend(AttentionBackend):
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
+        self.cuda_graph_custom_mask = torch.zeros(
+            (max_bs * self.max_context_len),
+            dtype=torch.uint8,
+            device="cuda",
+        )
+        self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+        self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: torch.Tensor
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        [11 removed lines; content not captured in this diff view]
+        if forward_mode.is_decode():
+            decode_wrappers = []
+            for i in range(self.num_wrappers):
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                            :num_tokens
+                        ],
+                    )
                 )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_decode.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                decode_wrappers=decode_wrappers,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
             )
-        [11 removed lines; content not captured in this diff view]
+            self.decode_cuda_graph_metadata[bs] = decode_wrappers
+            self.forward_metadata = DecodeMetadata(decode_wrappers)
+        elif forward_mode.is_target_verify():
+            prefill_wrappers = []
+            for i in range(self.num_wrappers):
+                prefill_wrappers.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
+                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
+                        custom_mask_buf=self.cuda_graph_custom_mask,
+                        qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+                    )
+                )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=prefill_wrappers,
+                use_ragged=False,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
+            self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -218,24 +299,41 @@ class FlashInferAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: torch.Tensor
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        [7 removed lines; content not captured in this diff view]
+        if forward_mode.is_decode():
+            self.indices_updater_decode.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                decode_wrappers=self.decode_cuda_graph_metadata[bs],
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        elif forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        else:
+            raise ValueError("Invalid forward mode")
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -249,6 +347,8 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )
 
+        logits_soft_cap = layer.logit_cap
+
        if not self.forward_metadata.use_ragged:
            if k is not None:
                assert v is not None
@@ -261,7 +361,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
-                logits_soft_cap=
+                logits_soft_cap=logits_soft_cap,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -270,7 +370,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
-                logits_soft_cap=
+                logits_soft_cap=logits_soft_cap,
             )
 
         if self.forward_metadata.extend_no_prefix:
@@ -293,9 +393,9 @@ class FlashInferAttnBackend(AttentionBackend):
 
     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -348,7 +448,6 @@ class FlashInferIndicesUpdaterDecode:
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.sliding_window_size = model_runner.sliding_window_size
-
         self.attn_backend = attn_backend
 
         # Buffers and wrappers
@@ -371,7 +470,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -382,7 +482,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -392,6 +493,7 @@ class FlashInferIndicesUpdaterDecode:
             seq_lens_sum,
             self.kv_indptr[0],
             None,
+            spec_info,
         )
 
     def update_sliding_window(
@@ -400,7 +502,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -424,6 +527,7 @@ class FlashInferIndicesUpdaterDecode:
                 paged_kernel_lens_sum_tmp,
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
+                spec_info,
             )
 
     def update_cross_attention(
@@ -432,7 +536,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -452,6 +557,7 @@ class FlashInferIndicesUpdaterDecode:
                 seq_lens_sum,
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
+                spec_info,
             )
 
     def call_begin_forward(
@@ -462,23 +568,30 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
+        spec_info: Optional[SpecInfo],
     ):
-        [16 removed lines; content not captured in this diff view]
+        if spec_info is None:
+            bs = len(req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+        else:
+            bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
+                req_pool_indices,
+                paged_kernel_lens,
+                self.req_to_token,
+            )
 
         wrapper.end_forward()
         wrapper.begin_forward(
@@ -507,7 +620,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.sliding_window_size = model_runner.sliding_window_size
-
         self.attn_backend = attn_backend
 
         # Buffers and wrappers
@@ -534,7 +646,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -547,7 +660,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -568,6 +682,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.kv_indptr[0],
             self.qo_indptr[0],
             use_ragged,
+            spec_info,
         )
 
     def update_sliding_window(
@@ -578,7 +693,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -607,6 +723,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.kv_indptr[wrapper_id],
                 self.qo_indptr[wrapper_id],
                 use_ragged,
+                spec_info,
             )
 
     def update_cross_attention(
@@ -617,7 +734,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -643,6 +761,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.kv_indptr[wrapper_id],
                 self.qo_indptr[wrapper_id],
                 use_ragged,
+                spec_info,
             )
 
     def call_begin_forward(
@@ -658,25 +777,37 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
+        spec_info: Optional[SpecInfo],
     ):
         bs = len(req_pool_indices)
-        [14 removed lines; content not captured in this diff view]
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
 
-        [2 removed lines; content not captured in this diff view]
+            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    self.req_to_token,
+                )
+            )
 
         # extend part
         if use_ragged:
@@ -702,6 +833,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.head_dim,
             1,
             q_data_type=self.q_data_type,
+            custom_mask=custom_mask,
         )
 
 
sglang/srt/layers/attention/torch_native_backend.py

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING
 
 import torch
 from torch.nn.functional import scaled_dot_product_attention
@@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         """Init the metadata for a forward pass."""
         pass
 
-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
     def _run_sdpa_forward_extend(
         self,
         query: torch.Tensor,