sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +13 -6
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +2 -4
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +40 -35
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +110 -74
- sglang/srt/managers/tokenizer_manager.py +24 -15
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +60 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +118 -141
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +6 -8
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +8 -43
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/{llama2.py → llama.py} +48 -26
- sglang/srt/models/llama_classification.py +14 -40
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +38 -16
- sglang/srt/models/llavavid.py +7 -8
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +67 -58
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +70 -0
- sglang/test/test_utils.py +89 -1
- sglang/utils.py +38 -4
- sglang/version.py +1 -1
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.2.15.dist-info/RECORD +0 -118
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
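Several modules change their import paths in this release: `sglang/srt/model_config.py` moves to `sglang/srt/configs/model_config.py`, the Triton attention kernels move into a `triton_attention` subpackage, and `sglang/srt/models/llama2.py` becomes `sglang/srt/models/llama.py`. A minimal sketch of the corresponding import updates, using only the module paths listed above (the symbols inside those modules are assumed unchanged):

```python
# sglang 0.2.15
# import sglang.srt.model_config as model_config
# import sglang.srt.layers.decode_attention as decode_attention

# sglang 0.3.1 -- same modules, new locations
import sglang.srt.configs.model_config as model_config
import sglang.srt.layers.triton_attention.decode_attention as decode_attention
```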
sglang/srt/model_executor/forward_batch_info.py

@@ -15,22 +15,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""
+"""Meta data for a forward pass."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
-import triton
-import triton.language as tl
-
-from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention_backend import AttentionBackend
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
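The hunk above drops the module-level `triton` imports and re-homes the batch/pool imports under `if TYPE_CHECKING:`, so they are only evaluated by static type checkers and are referenced through annotations at runtime, which is the usual way to avoid import-time cycles. A generic sketch of the pattern (the function below is illustrative, not from the diff; `ScheduleBatch.batch_size()` is the method used later in this file):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers, never executed at import time.
    from sglang.srt.managers.schedule_batch import ScheduleBatch


def describe(batch: "ScheduleBatch") -> str:
    # The quoted annotation keeps ScheduleBatch out of the runtime namespace.
    return f"batch of size {batch.batch_size()}"
```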
@@ -40,6 +37,20 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
+    # Contains both PREFILL and EXTEND.
+    MIXED = auto()
+
+    def is_prefill(self):
+        return self == ForwardMode.PREFILL
+
+    def is_extend(self):
+        return self == ForwardMode.EXTEND or self == ForwardMode.MIXED
+
+    def is_decode(self):
+        return self == ForwardMode.DECODE
+
+    def is_mixed(self):
+        return self == ForwardMode.MIXED
 
 
 @dataclass
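With `MIXED` and the predicate helpers added above, call sites can test the mode without comparing against enum members directly (the later hunks switch `compute_positions` and `from_schedule_batch` over to `is_decode()`). A small usage sketch, assuming sglang 0.3.1 is installed:

```python
from sglang.srt.model_executor.forward_batch_info import ForwardMode

mode = ForwardMode.MIXED     # a batch carrying both prefill and extend work
assert mode.is_extend()      # MIXED is treated as an extend batch...
assert mode.is_mixed()
assert not mode.is_decode()  # ...and never as a decode batch
```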
@@ -47,18 +58,16 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
-    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
+    attn_backend: AttentionBackend
 
     # Output location of the KV cache
     out_cache_loc: torch.Tensor
 
-    total_num_tokens: int = None
-
     # Position information
     positions: torch.Tensor = None
 
@@ -72,35 +81,25 @@ class InputMetadata:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
     extend_seq_lens_cpu: List[int] = None
-    logprob_start_lens_cpu: List[int] = None
+    extend_logprob_start_lens_cpu: List[int] = None
 
     # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[List[int]]] = None
     image_offsets: List[List[int]] = None
-
-    # Trition attention backend
-    triton_max_seq_len: int = 0
-    triton_max_extend_len: int = 0
-    triton_start_loc: torch.Tensor = None
-    triton_prefix_lens: torch.Tensor = None
-
-    # FlashInfer attention backend
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    flashinfer_use_ragged: bool = False
+    modalities: List[List[str]] = None
 
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_sizes for r in reqs]
         self.image_offsets = [r.image_offsets for r in reqs]
+        self.modalities = [r.modalities for r in reqs]
 
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
 
-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             if True:
                 self.positions = self.seq_lens - 1
             else:
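In the decode branch shown above, the position of each request's next token is simply `seq_lens - 1`. A standalone illustration of that line:

```python
import torch

seq_lens = torch.tensor([5, 3])  # two running requests with 5 and 3 tokens already cached
positions = seq_lens - 1         # 0-indexed positions of the tokens being decoded
print(positions)                 # tensor([4, 2])
```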
@@ -139,315 +138,39 @@ class InputMetadata:
         self.positions = self.positions.to(torch.int64)
 
     def compute_extend_infos(self, batch: ScheduleBatch):
-
-
-
-
-
-
-
-        ]
-        self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
-        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
-
-        self.extend_seq_lens_cpu = extend_lens_cpu
-        self.logprob_start_lens_cpu = [
-            (
-                min(
-                    req.logprob_start_len - batch.prefix_lens_cpu[i],
-                    extend_lens_cpu[i] - 1,
-                )
-                if req.logprob_start_len >= batch.prefix_lens_cpu[i]
-                else extend_lens_cpu[i] - 1  # Fake extend, actually decode
-            )
-            for i, req in enumerate(batch.reqs)
-        ]
+        self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
+        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+        self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
+        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+        self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
+        self.extend_seq_lens_cpu = batch.extend_lens_cpu
+        self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
 
     @classmethod
     def from_schedule_batch(
         cls,
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
-        forward_mode: ForwardMode,
     ):
         ret = cls(
-            forward_mode=forward_mode,
-            sampling_info=batch.sampling_info,
+            forward_mode=batch.forward_mode,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
+            attn_backend=model_runner.attn_backend,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
         )
 
-        ret.sampling_info.prepare_penalties()
-
         ret.compute_positions(batch)
 
-
-
-        if (
-            forward_mode != ForwardMode.DECODE
-            or model_runner.server_args.disable_flashinfer
-        ):
-            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
-
-        if forward_mode != ForwardMode.DECODE:
+        if not batch.forward_mode.is_decode():
             ret.init_multimuldal_info(batch)
+            ret.compute_extend_infos(batch)
 
-
-            ret.init_triton_args(batch)
-
-        flashinfer_use_ragged = False
-        if not model_runner.server_args.disable_flashinfer:
-            if (
-                forward_mode != ForwardMode.DECODE
-                and int(torch.sum(ret.seq_lens)) > 4096
-                and model_runner.sliding_window_size is None
-            ):
-                flashinfer_use_ragged = True
-            ret.init_flashinfer_handlers(
-                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
-            )
+        model_runner.attn_backend.init_forward_metadata(batch, ret)
 
         return ret
-
-    def init_triton_args(self, batch: ScheduleBatch):
-        """Init auxiliary variables for triton attention backend."""
-        self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
-        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
-
-        if self.forward_mode == ForwardMode.DECODE:
-            self.triton_max_extend_len = None
-        else:
-            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
-            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
-
-    def init_flashinfer_handlers(
-        self,
-        model_runner,
-        prefix_lens_cpu,
-        flashinfer_use_ragged,
-    ):
-        if self.forward_mode == ForwardMode.DECODE:
-            prefix_lens = None
-        else:
-            prefix_lens = self.extend_prefix_lens
-
-        update_flashinfer_indices(
-            self.forward_mode,
-            model_runner,
-            self.req_pool_indices,
-            self.seq_lens,
-            prefix_lens,
-            flashinfer_use_ragged=flashinfer_use_ragged,
-        )
-
-        (
-            self.flashinfer_prefill_wrapper_ragged,
-            self.flashinfer_prefill_wrapper_paged,
-            self.flashinfer_decode_wrapper,
-            self.flashinfer_use_ragged,
-        ) = (
-            model_runner.flashinfer_prefill_wrapper_ragged,
-            model_runner.flashinfer_prefill_wrapper_paged,
-            model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged,
-        )
-
-
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    max_context_len,
-    kv_indices_ptr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
-
-
-def update_flashinfer_indices(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    flashinfer_decode_wrapper=None,
-    flashinfer_use_ragged=False,
-):
-    """Init auxiliary variables for FlashInfer attention backend."""
-    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
-    head_dim = model_runner.model_config.head_dim
-    batch_size = len(req_pool_indices)
-
-    if model_runner.sliding_window_size is None:
-        if flashinfer_use_ragged:
-            paged_kernel_lens = prefix_lens
-        else:
-            paged_kernel_lens = seq_lens
-
-        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-
-        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
-        create_flashinfer_kv_indices_triton[(batch_size,)](
-            model_runner.req_to_token_pool.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            None,
-            model_runner.req_to_token_pool.req_to_token.size(1),
-            kv_indices,
-        )
-
-        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-        if forward_mode == ForwardMode.DECODE:
-            # CUDA graph uses different flashinfer_decode_wrapper
-            if flashinfer_decode_wrapper is None:
-                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-            flashinfer_decode_wrapper.end_forward()
-            flashinfer_decode_wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                data_type=model_runner.kv_cache_dtype,
-                q_data_type=model_runner.dtype,
-            )
-        else:
-            # extend part
-            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-            if flashinfer_use_ragged:
-                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                )
-
-            # cached part
-            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-    else:
-        # window attention use paged only
-        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-        for wrapper_id in range(2):
-            if wrapper_id == 0:
-                if forward_mode == ForwardMode.DECODE:
-                    paged_kernel_lens = torch.minimum(
-                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
-                    )
-                else:
-                    paged_kernel_lens = torch.minimum(
-                        seq_lens,
-                        torch.tensor(model_runner.sliding_window_size)
-                        + seq_lens
-                        - prefix_lens,
-                    )
-            else:
-                paged_kernel_lens = seq_lens
-
-            kv_start_idx = seq_lens - paged_kernel_lens
-
-            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-
-            kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
-            create_flashinfer_kv_indices_triton[(batch_size,)](
-                model_runner.req_to_token_pool.req_to_token,
-                req_pool_indices,
-                paged_kernel_lens,
-                kv_indptr,
-                kv_start_idx,
-                model_runner.req_to_token_pool.req_to_token.size(1),
-                kv_indices,
-            )
-
-            if forward_mode == ForwardMode.DECODE:
-                # CUDA graph uses different flashinfer_decode_wrapper
-                if flashinfer_decode_wrapper is None:
-                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-                flashinfer_decode_wrapper[wrapper_id].end_forward()
-                flashinfer_decode_wrapper[wrapper_id].begin_forward(
-                    kv_indptr,
-                    kv_indices,
-                    kv_last_page_len,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                    1,
-                    data_type=model_runner.kv_cache_dtype,
-                    q_data_type=model_runner.dtype,
-                )
-            else:
-                # extend part
-                qo_indptr = torch.zeros(
-                    (batch_size + 1,), dtype=torch.int32, device="cuda"
-                )
-                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
-                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
-                    qo_indptr,
-                    kv_indptr,
-                    kv_indices,
-                    kv_last_page_len,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                    1,
-                )
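The backend-specific setup removed above (`init_triton_args`, `init_flashinfer_handlers`, `update_flashinfer_indices`, and the Triton kv-indices kernel) now lives behind the new `sglang/srt/layers/attention_backend.py` and `sglang/srt/layers/flashinfer_utils.py` modules; `from_schedule_batch` only calls `model_runner.attn_backend.init_forward_metadata(batch, ret)`. The real interface is not part of this hunk, so the following is only a sketch of how the removed Triton bookkeeping could be re-expressed behind that hook; the class and attribute names are illustrative:

```python
import torch


class TritonStyleBackendSketch:
    """Illustrative only: re-homes the removed init_triton_args logic."""

    def init_forward_metadata(self, batch, input_metadata):
        seq_lens = input_metadata.seq_lens
        # Per-request start offsets and maximum lengths, as the old code computed them.
        self.max_seq_len = int(torch.max(seq_lens))
        self.start_loc = torch.zeros_like(seq_lens, dtype=torch.int32)
        self.start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

        if input_metadata.forward_mode.is_decode():
            self.max_extend_len = None
        else:
            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device=seq_lens.device)
            self.max_extend_len = int(torch.max(seq_lens - prefix_lens))
```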