sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +31 -13
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/conversation.py +11 -2
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/data_parallel_controller.py +177 -0
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +11 -2
- sglang/srt/managers/schedule_batch.py +126 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +245 -142
- sglang/srt/managers/tokenizer_manager.py +14 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +100 -36
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +97 -52
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +105 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +171 -37
- sglang/srt/server_args.py +127 -48
- sglang/srt/utils.py +37 -14
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
- sglang-0.3.4.dist-info/RECORD +143 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.dist-info/RECORD +0 -139
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
--- sglang/srt/layers/attention/flashinfer_backend.py (sglang 0.3.3)
+++ sglang/srt/layers/attention/flashinfer_backend.py (sglang 0.3.4)
@@ -7,18 +7,17 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

+from enum import Enum, auto
 from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn
+import triton
+import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.
-    WrapperDispatch,
-    update_flashinfer_indices,
-)
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
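For orientation, the module docstring above spells out the contract this file implements: one metadata-initialization step per batch plus the two operators, extend (prefill on top of a cached prefix) and decode. The sketch below is an illustration only, not part of the diff; the method names mirror those used in the hunks that follow, and the bodies are placeholders.

```python
# Illustrative sketch only (not code from the package): the rough shape of an
# attention backend as used in this file. Bodies are placeholders.
import torch.nn as nn


class SketchAttnBackend:
    def init_forward_metadata(self, forward_batch):
        # Precompute per-batch indices/wrappers before running the layers.
        self.forward_metadata = None

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch):
        # Extend: prefill new tokens on top of a cached prefix.
        return q

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch):
        # Decode: one new token per request, attending over the cached KV.
        return q
```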
@@ -34,13 +33,18 @@ if is_flashinfer_available():
     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels


+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""

     def __init__(self, model_runner: ModelRunner):
         super().__init__()
-        self.model_runner = model_runner

+        # Parse constants
         if not _grouped_size_compiled_for_decode_kernels(
             model_runner.model_config.num_attention_heads // model_runner.tp_size,
             model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
@@ -48,27 +52,43 @@ class FlashInferAttnBackend(AttentionBackend):
             self.decode_use_tensor_cores = True
         else:
             self.decode_use_tensor_cores = False
-
-        self.workspace_buffer = torch.empty(
-            global_config.flashinfer_workspace_size,
-            dtype=torch.uint8,
-            device="cuda",
-        )
+        self.max_context_len = model_runner.model_config.context_len

         assert not (
             model_runner.sliding_window_size is not None
             and model_runner.has_cross_attention
         ), "Sliding window and cross attention are not supported together"

-        self.num_wrappers = 1
-        self.dispatch_reason = None
         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
         elif model_runner.has_cross_attention:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
+        else:
+            self.num_wrappers = 1
+            self.dispatch_reason = None
+
+        # Allocate buffers
+        self.workspace_buffer = torch.empty(
+            global_config.flashinfer_workspace_size,
+            dtype=torch.uint8,
+            device=model_runner.device,
+        )
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
+        self.qo_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]

+        # Create wrappers
         # NOTE: we do not use ragged attention when there are multiple wrappers
         self.prefill_wrapper_ragged = (
             BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
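The buffers added in this hunk follow FlashInfer's paged-KV layout with a page size of 1: `kv_indptr` is a CSR-style offset array of length `batch_size + 1` filled by a cumulative sum of per-request KV lengths, and `kv_last_page_len` stays all ones because every page holds exactly one token. A minimal sketch (illustration only, not code from the package) of how these buffers end up populated:

```python
# Minimal sketch of the CSR-style layout used above (page size 1).
import torch

seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)  # cached KV length per request
bs = seq_lens.numel()

kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)          # tensor([0, 5, 8, 15])

kv_last_page_len = torch.ones(bs, dtype=torch.int32)   # always 1 when page size is 1

# kv_indices would then hold kv_indptr[-1] == 15 token-slot indices, with the
# slots of request i stored at kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
```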
@@ -92,26 +112,23 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
             )

+        # Create indices updater
+        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
+        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
+            model_runner, self
+        )
+
+        # Other metadata
         self.forward_metadata = None
         self.cuda_graph_metadata = {}

-    def _get_wrapper_idx(self, layer: nn.Module):
-        if self.num_wrappers == 1:
-            return 0
-
-        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-            return layer.sliding_window_size == -1
-        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
-            return layer.is_cross_attention
-
-        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
-
-
-
-
+            self.indices_updater_decode.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+            )
+            self.forward_metadata = (self.decode_wrappers,)
         else:
             prefix_lens = forward_batch.extend_prefix_lens

@@ -123,48 +140,32 @@ class FlashInferAttnBackend(AttentionBackend):
             ):
                 use_ragged = True

-            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
             extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()

-
-
-
-
-
-
-                use_ragged=use_ragged,
-            )
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                prefix_lens,
+                use_ragged,
+            )

-
-
-
-
-                self.decode_wrappers,
-            )
+            self.forward_metadata = (
+                use_ragged,
+                extend_no_prefix,
+            )

     def init_cuda_graph_state(self, max_bs: int):
-
-        (max_bs
-        )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.model_runner.model_config.context_len,),
+        cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.max_context_len,),
             dtype=torch.int32,
             device="cuda",
         )
-        self.
-        (
-        )
-
-        # NOTE: the buffers are always in the form of list
-        self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
-            self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
-        ]
-        self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
-            self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+        self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
+            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         decode_wrappers = []
         for i in range(self.num_wrappers):
@@ -174,35 +175,21 @@ class FlashInferAttnBackend(AttentionBackend):
                     "NHD",
                     use_cuda_graph=True,
                     use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.
+                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
                     paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.
+                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
                 )
             )

-
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            decode_wrappers,
-        )
-
+        self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers)
         self.cuda_graph_metadata[bs] = decode_wrappers
-
-        self.forward_metadata = (False, False, None, decode_wrappers)
+        self.forward_metadata = (decode_wrappers,)

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
-
-
-            self.model_runner,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            None,
-            self.cuda_graph_metadata[bs],
+        self.indices_updater_decode.update(
+            req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs]
         )

     def get_cuda_graph_seq_len_fill_value(self):
@@ -213,7 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self._get_wrapper_idx(layer)
         ]

-        use_ragged, extend_no_prefix
+        use_ragged, extend_no_prefix = self.forward_metadata

         if not use_ragged:
             if k is not None:
@@ -259,7 +246,7 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        decode_wrapper = self.forward_metadata[
+        decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]

         if k is not None:
             assert v is not None
@@ -275,3 +262,285 @@ class FlashInferAttnBackend(AttentionBackend):
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def _get_wrapper_idx(self, layer: nn.Module):
+        if self.num_wrappers == 1:
+            return 0
+
+        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            return layer.is_cross_attention
+
+        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
+
+
+class FlashInferIndicesUpdaterDecode:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Constants
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
+        self.sliding_window_size = model_runner.sliding_window_size
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.decode_wrappers = attn_backend.decode_wrappers
+
+        # Dispatch
+        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            self.update = self.update_sliding_window
+        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            self.update = self.update_cross_attention
+        else:
+            assert attn_backend.num_wrappers == 1
+            self.update = self.update_single_wrapper
+
+    def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+        self.call_begin_forward(
+            decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None
+        )
+
+    def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # Sliding window attention
+                paged_kernel_lens = torch.minimum(  # TODO: replace this with clamp
+                    seq_lens,
+                    torch.tensor(self.sliding_window_size + 1),
+                )
+            else:
+                # Full attention
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            self.call_begin_forward(
+                decode_wrappers[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                self.kv_indptr[wrapper_id],
+                kv_start_idx,
+            )
+
+    def update_cross_attention(self):
+        raise NotImplementedError()
+
+    def call_begin_forward(
+        self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx
+    ):
+        bs = len(req_pool_indices)
+        kv_indptr = kv_indptr[: bs + 1]
+        # TODO: optimize the blocking call on kv_indptr[-1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            kv_indices,
+            self.max_context_len,
+        )
+
+        wrapper.end_forward()
+        wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.data_type,
+            q_data_type=self.q_data_type,
+        )
+
+
+class FlashInferIndicesUpdaterPrefill:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Constants
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
+        self.sliding_window_size = model_runner.sliding_window_size
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
+        self.qo_indptr = attn_backend.qo_indptr
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.wrappers_paged = attn_backend.prefill_wrappers_paged
+
+        # Dispatch
+        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            self.update = self.update_sliding_window
+        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            self.update = self.update_cross_attention
+        else:
+            assert attn_backend.num_wrappers == 1
+            self.update = self.update_single_wrapper
+
+    def update_single_wrapper(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+    ):
+        if use_ragged:
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        self.call_begin_forward(
+            self.wrapper_ragged,
+            self.wrappers_paged[0],
+            req_pool_indices,
+            paged_kernel_lens,
+            seq_lens,
+            prefix_lens,
+            None,
+            self.kv_indptr[0],
+            self.qo_indptr[0],
+            use_ragged,
+        )
+
+    def update_sliding_window(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+    ):
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # window attention use paged only
+                paged_kernel_lens = torch.minimum(
+                    seq_lens,
+                    torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
+                )
+            else:
+                # full attention
+                paged_kernel_lens = seq_lens
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            self.call_begin_forward(
+                self.wrapper_ragged,
+                self.wrappers_paged[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens,
+                prefix_lens,
+                kv_start_idx,
+                self.kv_indptr[wrapper_id],
+                self.qo_indptr[wrapper_id],
+                use_ragged,
+            )
+
+    def update_cross_attention(self):
+        raise NotImplementedError()
+
+    def call_begin_forward(
+        self,
+        wrapper_ragged,
+        wrapper_paged,
+        req_pool_indices,
+        paged_kernel_lens,
+        seq_lens,
+        prefix_lens,
+        kv_start_idx,
+        kv_indptr,
+        qo_indptr,
+        use_ragged,
+    ):
+        bs = len(req_pool_indices)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            kv_indices,
+            self.max_context_len,
+        )
+
+        qo_indptr = qo_indptr[: bs + 1]
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        # extend part
+        if use_ragged:
+            wrapper_ragged.end_forward()
+            wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+            )
+
+        # cached part
+        wrapper_paged.end_forward()
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+        )
+
+
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    max_context_len: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
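The new `create_flashinfer_kv_indices_triton` kernel flattens the 2D `req_to_token` map into the 1D `kv_indices` array that FlashInfer expects: for each request it copies `req_to_token[req_pool_index, kv_start : kv_start + paged_kernel_len]` into `kv_indices` at offset `kv_indptr[i]`. A rough PyTorch reference of the same computation (illustration only, not part of the package; the helper name is made up):

```python
import torch


def build_kv_indices_reference(req_to_token, req_pool_indices, paged_kernel_lens,
                               kv_indptr, kv_start_idx=None):
    """Loop-based equivalent of what create_flashinfer_kv_indices_triton computes."""
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i, pool_idx in enumerate(req_pool_indices.tolist()):
        start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        length = int(paged_kernel_lens[i])
        out_offset = int(kv_indptr[i])
        # Copy the KV-cache slots of request i into the flat output array.
        kv_indices[out_offset : out_offset + length] = req_to_token[
            pool_idx, start : start + length
        ]
    return kv_indices
```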
--- sglang/srt/layers/attention/triton_backend.py (sglang 0.3.3)
+++ sglang/srt/layers/attention/triton_backend.py (sglang 0.3.4)
@@ -40,6 +40,8 @@ class TritonAttnBackend(AttentionBackend):

         self.cuda_graph_max_seq_len = model_runner.model_config.context_len

+        self.device = model_runner.device
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

@@ -51,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
         attn_logits = torch.empty(
             (self.num_head, total_num_tokens),
             dtype=self.reduce_dtype,
-            device=
+            device=self.device,
         )

         max_seq_len = torch.max(forward_batch.seq_lens).item()
@@ -67,7 +69,7 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len

         self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device=
+            (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
             (
@@ -79,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.forward_metadata = (
             self.cuda_graph_start_loc,
@@ -89,7 +91,7 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)