sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +6 -25
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +104 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +58 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +117 -131
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +57 -44
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -1
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention_backend.py (new file)
@@ -0,0 +1,480 @@
```python
from __future__ import annotations

"""
Support different attention backends.
Now there are two backends: FlashInfer and Triton.
FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


class AttentionBackend(ABC):
    """The base class of attention backends"""

    @abstractmethod
    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        """Init the metadata for a forward pass."""
        raise NotImplementedError()

    def init_cuda_graph_state(self, max_bs: int):
        """Init the global shared states for cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        """Init the metadata for a forward pass for capturing a cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        """Init the metadata for a forward pass for replaying a cuda graph."""
        raise NotImplementedError()

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
        raise NotImplementedError()

    def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        """Run forward on an attention layer."""
        if input_metadata.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, input_metadata)
        else:
            return self.forward_extend(q, k, v, layer, input_metadata)

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        """Run a forward for decode."""
        raise NotImplementedError()

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        """Run a forward for extend."""
        raise NotImplementedError()


class FlashInferAttnBackend(AttentionBackend):
    """Flashinfer attention kernels."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.model_runner = model_runner

        local_num_qo_heads = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )
        local_num_kv_heads = model_runner.model_config.get_num_kv_heads(
            model_runner.tp_size
        )
        if (
            not _grouped_size_compiled_for_decode_kernels(
                local_num_qo_heads, local_num_kv_heads
            )
            or local_num_qo_heads // local_num_kv_heads > 4
        ):
            self.decode_use_tensor_cores = True
        else:
            self.decode_use_tensor_cores = False

        self.workspace_buffer = torch.empty(
            global_config.flashinfer_workspace_size,
            dtype=torch.uint8,
            device="cuda",
        )

        if model_runner.sliding_window_size is None:
            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )
            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_tensor_cores=self.decode_use_tensor_cores,
            )
        else:
            # Two wrappers: one for sliding window attention and one for full attention.
            # Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs.
            self.prefill_wrapper_ragged = None
            self.prefill_wrapper_paged = []
            self.decode_wrapper = []
            for _ in range(2):
                self.prefill_wrapper_paged.append(
                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
                )
                self.decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
                        "NHD",
                        use_tensor_cores=self.decode_use_tensor_cores,
                    )
                )

        self.forward_metadata = None
        self.cuda_graph_metadata = {}

    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        if input_metadata.forward_mode.is_decode():
            prefix_lens = None
            use_ragged = False
            total_num_tokens = None
        else:
            prefix_lens = input_metadata.extend_prefix_lens

            # Some heuristics to check whether to use ragged forward
            use_ragged = False
            if (
                int(torch.sum(input_metadata.seq_lens)) > 4096
                and self.model_runner.sliding_window_size is None
            ):
                use_ragged = True

            total_num_tokens = torch.sum(input_metadata.seq_lens).item()

        update_flashinfer_indices(
            input_metadata.forward_mode,
            self.model_runner,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            prefix_lens,
            use_ragged=use_ragged,
        )

        self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)

    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_kv_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_kv_indices = torch.zeros(
            (max_bs * self.model_runner.model_config.context_len,),
            dtype=torch.int32,
            device="cuda",
        )
        self.cuda_graph_kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device="cuda"
        )

        if self.model_runner.sliding_window_size is not None:
            self.cuda_graph_kv_indptr = [
                self.cuda_graph_kv_indptr,
                self.cuda_graph_kv_indptr.clone(),
            ]
            self.cuda_graph_kv_indices = [
                self.cuda_graph_kv_indices,
                self.cuda_graph_kv_indices.clone(),
            ]

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        if self.model_runner.sliding_window_size is None:
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_cuda_graph=True,
                use_tensor_cores=self.decode_use_tensor_cores,
                paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
                paged_kv_indices_buffer=self.cuda_graph_kv_indices,
                paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
            )
        else:
            decode_wrapper = []
            for i in range(2):
                decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
                        "NHD",
                        use_cuda_graph=True,
                        use_tensor_cores=self.decode_use_tensor_cores,
                        paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
                        paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
                            :bs
                        ],
                    )
                )

        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices,
            seq_lens,
            None,
            decode_wrapper,
        )

        self.cuda_graph_metadata[bs] = decode_wrapper

        self.forward_metadata = (False, None, decode_wrapper)

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices[:bs],
            seq_lens[:bs],
            None,
            self.cuda_graph_metadata[bs],
        )

    def get_cuda_graph_seq_len_fill_value(self):
        return 0

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        if not isinstance(self.prefill_wrapper_paged, list):
            prefill_wrapper_paged = self.prefill_wrapper_paged
        else:
            if layer.sliding_window_size != -1:
                prefill_wrapper_paged = self.prefill_wrapper_paged[0]
            else:
                prefill_wrapper_paged = self.prefill_wrapper_paged[1]

        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata

        if not use_ragged:
            if k is not None:
                assert v is not None
                input_metadata.token_to_kv_pool.set_kv_buffer(
                    layer.layer_id, input_metadata.out_cache_loc, k, v
                )
            o = prefill_wrapper_paged.forward(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=True,
                sm_scale=layer.scaling,
                window_left=layer.sliding_window_size,
                logits_soft_cap=layer.logit_cap,
            )
        else:
            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=layer.logit_cap,
            )

            if input_metadata.extend_no_prefix:
                o = o1
            else:
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                    causal=False,
                    sm_scale=layer.scaling,
                    logits_soft_cap=layer.logit_cap,
                )

                o, _ = merge_state(o1, s1, o2, s2)

            input_metadata.token_to_kv_pool.set_kv_buffer(
                layer.layer_id, input_metadata.out_cache_loc, k, v
            )

            if total_num_tokens >= global_config.layer_sync_threshold:
                # TODO: Revisit this. Why is this synchronize needed?
                torch.cuda.synchronize()

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata

        if isinstance(decode_wrapper, list):
            if layer.sliding_window_size != -1:
                decode_wrapper = decode_wrapper[0]
            else:
                decode_wrapper = decode_wrapper[1]

        if k is not None:
            assert v is not None
            input_metadata.token_to_kv_pool.set_kv_buffer(
                layer.layer_id, input_metadata.out_cache_loc, k, v
            )

        o = decode_wrapper.forward(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            sm_scale=layer.scaling,
            logits_soft_cap=layer.logit_cap,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)


class TritonAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.triton_attention.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.triton_attention.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd
        self.num_head = model_runner.model_config.num_attention_heads

        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16

        self.forward_metadata = None

        self.cuda_graph_max_seq_len = model_runner.model_config.context_len

    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        """Init auxiliary variables for triton attention backend."""

        if input_metadata.forward_mode.is_decode():
            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)

            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
            attn_logits = torch.empty(
                (self.num_head, total_num_tokens),
                dtype=self.reduce_dtype,
                device="cuda",
            )

            max_seq_len = torch.max(input_metadata.seq_lens).item()
            max_extend_len = None
        else:
            start_loc = attn_logits = max_seq_len = None
            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len

    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len

        self.cuda_graph_start_loc = torch.zeros(
            (max_bs,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_attn_logits = torch.empty(
            (
                self.num_head,
                self.cuda_graph_max_total_num_tokens,
            ),
            dtype=self.reduce_dtype,
            device="cuda",
        )

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.forward_metadata = (
            self.cuda_graph_start_loc,
            self.cuda_graph_attn_logits,
            self.cuda_graph_max_seq_len,
            None,
        )

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v
        )

        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            input_metadata.extend_seq_lens,
            input_metadata.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v
        )

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            start_loc,
            input_metadata.seq_lens,
            attn_logits,
            max_seq_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o
```