sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -25,16 +25,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather

@@ -84,13 +86,20 @@ def set_torch_compile_config():


 class CudaGraphRunner:
-    def __init__(
+    def __init__(
+        self,
+        model_runner,
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding

         # Common inputs
         self.max_bs = max_batch_size_to_capture

@@ -98,8 +107,8 @@ class CudaGraphRunner:
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.
-        self.position_ids_offsets = torch.
+        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
         self.out_cache_loc = torch.zeros(
@@ -107,9 +116,6 @@ class CudaGraphRunner:
         )

         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = (
-            self.model_runner.flashinfer_workspace_buffers[0]
-        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )

@@ -121,6 +127,27 @@ class CudaGraphRunner:
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
+        if model_runner.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffer
+            )
+        else:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffer
+            )
+
+            self.flashinfer_kv_indptr = [
+                self.flashinfer_kv_indptr,
+                self.flashinfer_kv_indptr.clone(),
+            ]
+            self.flashinfer_kv_indices = [
+                self.flashinfer_kv_indices,
+                self.flashinfer_kv_indices.clone(),
+            ]
+
+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)

         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []

@@ -128,7 +155,10 @@ class CudaGraphRunner:
             set_torch_compile_config()

     def can_run(self, batch_size):
-
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs

     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list
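The new `disable_padding` flag narrows which batch sizes may be served from a captured graph. A minimal sketch of that dispatch rule, with a hypothetical `captured_sizes` list standing in for the runner's captured graphs (not sglang's actual API):

```python
import bisect

# Hypothetical illustration of the can_run() rule added above.
captured_sizes = [1, 2, 4, 8, 16, 24, 32]  # batch sizes that have a captured graph

def can_replay(raw_bs: int, max_bs: int, disable_padding: bool) -> bool:
    if disable_padding:
        # Only exact matches replay; other sizes fall back to the normal (eager) path.
        return raw_bs in captured_sizes
    # With padding allowed, any batch up to max_bs is padded to the next captured size.
    return raw_bs <= max_bs

print(can_replay(3, 32, disable_padding=True))   # False -> eager fallback
print(can_replay(3, 32, disable_padding=False))  # True  -> padded up to bs=4
```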
@@ -171,15 +201,32 @@ class CudaGraphRunner:
             use_tensor_cores = True
         else:
             use_tensor_cores = False
-
-
-
-
-
-
-
-
-
+        if self.model_runner.sliding_window_size is None:
+            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_cuda_graph=True,
+                use_tensor_cores=use_tensor_cores,
+                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
+                paged_kv_indices_buffer=self.flashinfer_kv_indices,
+                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+            )
+        else:
+            flashinfer_decode_wrapper = []
+            for i in range(2):
+                flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=use_tensor_cores,
+                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
+                            :bs
+                        ],
+                    )
+                )
         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
@@ -193,6 +240,7 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,

@@ -201,19 +249,30 @@ class CudaGraphRunner:
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=0,
-                positions=(seq_lens - 1).to(torch.int64),
+                positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
                 flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )

             return forward(input_ids, input_metadata.positions, input_metadata)

         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()

+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper

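The synchronize/barrier pairs added around capture follow the usual CUDA graph recipe: warm the callable up, then quiesce the device (and, in sglang, the tensor-parallel group) before and after recording. A stripped-down sketch of that pattern in plain PyTorch, with no tensor-parallel group:

```python
import torch

def capture(run_once, warmup_iters: int = 2):
    """Warm up, then record run_once() into a replayable CUDA graph (sketch only)."""
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        for _ in range(warmup_iters):
            run_once()                      # warm-up triggers lazy inits and allocations
        torch.cuda.synchronize()            # nothing may be in flight when recording starts

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            out = run_once()                # recorded once; later re-run via graph.replay()
        torch.cuda.synchronize()
    return graph, out
```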
@@ -225,8 +284,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.
-            self.position_ids_offsets.
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()

         # Common inputs

@@ -246,25 +305,35 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )

+        # Sampling inputs
+        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
-
+        torch.cuda.synchronize()
+        sample_output, logits_output = self.output_buffers[bs]

         # Unpad
         if bs != raw_bs:
-
-            next_token_logits=
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=logits_output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
+            sample_output = SampleOutput(
+                sample_output.success[:raw_bs],
+                sample_output.probs[:raw_bs],
+                sample_output.batch_next_token_ids[:raw_bs],
+            )

         # Extract logprobs
         if batch.return_logprob:
-
-
+            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:

@@ -272,8 +341,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-
-
+                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    logits_output.next_token_logprobs, logits_metadata
                 )[1]

-        return
+        return sample_output, logits_output
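Replay now pads the incoming batch up to the nearest captured size and slices the outputs back down. A self-contained sketch of that pad-and-unpad flow, using a bare logits tensor instead of the `SampleOutput`/`LogitsProcessorOutput` wrappers:

```python
import bisect
import torch

captured_sizes = [1, 2, 4, 8, 16, 24, 32]
vocab_size = 32000  # arbitrary value for illustration
# Output buffers live at the captured (padded) batch sizes.
output_buffers = {bs: torch.zeros(bs, vocab_size) for bs in captured_sizes}

def replay(raw_bs: int) -> torch.Tensor:
    # Pad to the nearest captured batch size (the same bisect trick as in the diff).
    bs = captured_sizes[bisect.bisect_left(captured_sizes, raw_bs)]
    # ... graph.replay() would refresh output_buffers[bs] here ...
    logits = output_buffers[bs]
    # Unpad: only the first raw_bs rows belong to real requests.
    return logits[:raw_bs]

print(replay(3).shape)  # torch.Size([3, 32000])
```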
sglang/srt/model_executor/forward_batch_info.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");

@@ -26,6 +28,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


 class ForwardMode(IntEnum):
@@ -42,6 +45,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""

     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor

@@ -61,9 +65,11 @@ class InputMetadata:
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None

-    #
+    # For logprob
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None

     # For multimodal
     pixel_values: List[torch.Tensor] = None
@@ -86,14 +92,19 @@ class InputMetadata:
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-
-
-
-
-
-
-
+        self.image_offsets = []
+        for r in reqs:
+            if isinstance(r.image_offset, list):
+                self.image_offsets.append(
+                    [
+                        (image_offset - len(r.prefix_indices))
+                        for image_offset in r.image_offset
+                    ]
+                )
+            elif isinstance(r.image_offset, int):
+                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
+            elif r.image_offset is None:
+                self.image_offsets.append(0)

     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
@@ -109,8 +120,8 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(
-                        for req in batch.reqs
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
                     ],
                     axis=0,
                 ),

@@ -123,7 +134,7 @@ class InputMetadata:
                 np.concatenate(
                     [
                         np.arange(
-
+                            batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                             len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)
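With a cached prefix, a request only extends from `prefix_len` onward, so its position ids start there instead of at 0. A small worked example of the concatenation above (ignoring `position_ids_offsets`):

```python
import numpy as np

# Two requests: one with 3 cached prefix tokens out of 7 total, one with no prefix.
prefix_lens_cpu = [3, 0]
fill_ids_lens = [7, 4]  # stand-in for len(req.fill_ids)

positions = np.concatenate(
    [np.arange(prefix_lens_cpu[i], fill_ids_lens[i]) for i in range(2)]
)
print(positions)  # [3 4 5 6 0 1 2 3]
```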
@@ -139,14 +150,29 @@ class InputMetadata:
     def compute_extend_infos(self, batch: ScheduleBatch):
        if self.forward_mode == ForwardMode.DECODE:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+            self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
         else:
             extend_lens_cpu = [
-                len(r.fill_ids) -
+                len(r.fill_ids) - batch.prefix_lens_cpu[i]
+                for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(
+            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
+
+            self.extend_seq_lens_cpu = extend_lens_cpu
+            self.logprob_start_lens_cpu = [
+                (
+                    min(
+                        req.logprob_start_len - batch.prefix_lens_cpu[i],
+                        extend_lens_cpu[i] - 1,
+                    )
+                    if req.logprob_start_len >= batch.prefix_lens_cpu[i]
+                    else extend_lens_cpu[i] - 1  # Fake extend, actually decode
+                )
+                for i, req in enumerate(batch.reqs)
+            ]

     @classmethod
     def from_schedule_batch(
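The new `logprob_start_lens_cpu` re-bases each request's `logprob_start_len` into its extend segment and clamps it to the last extend token; when the requested start falls inside the cached prefix, only the last token is kept. A tiny sketch of that per-request rule (a plain helper, not the library API):

```python
def logprob_start_in_extend(logprob_start_len: int, prefix_len: int, extend_len: int) -> int:
    if logprob_start_len >= prefix_len:
        # Start lies inside the extend part; never point past the last extend token.
        return min(logprob_start_len - prefix_len, extend_len - 1)
    # Start lies inside the cached prefix ("fake extend"): keep only the last token.
    return extend_len - 1

print(logprob_start_in_extend(5, prefix_len=3, extend_len=4))  # 2
print(logprob_start_in_extend(1, prefix_len=3, extend_len=4))  # 3
```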
@@ -157,6 +183,7 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
+            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,

@@ -167,6 +194,8 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )

+        ret.sampling_info.prepare_penalties()
+
         ret.compute_positions(batch)

         ret.compute_extend_infos(batch)
@@ -180,44 +209,47 @@ class InputMetadata:
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)

-        prefix_lens = None
-        if forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(
-                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
-            )
-
         if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch
+            ret.init_triton_args(batch)

         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
             if (
                 forward_mode != ForwardMode.DECODE
                 and int(torch.sum(ret.seq_lens)) > 4096
+                and model_runner.sliding_window_size is None
             ):
                 flashinfer_use_ragged = True
             ret.init_flashinfer_handlers(
-                model_runner,
+                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
             )

         return ret

-    def init_triton_args(self, batch: ScheduleBatch
+    def init_triton_args(self, batch: ScheduleBatch):
         """Init auxiliary variables for triton attention backend."""
         self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)

         if self.forward_mode == ForwardMode.DECODE:
             self.triton_max_extend_len = None
         else:
-
+            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
             self.triton_max_extend_len = int(torch.max(extend_seq_lens))

     def init_flashinfer_handlers(
-        self,
+        self,
+        model_runner,
+        prefix_lens_cpu,
+        flashinfer_use_ragged,
     ):
+        if self.forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+        else:
+            prefix_lens = None
+
         update_flashinfer_indices(
             self.forward_mode,
             model_runner,
@@ -255,65 +287,139 @@ def update_flashinfer_indices(
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)

-    if
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        # CUDA graph uses different flashinfer_decode_wrapper
-        if flashinfer_decode_wrapper is None:
-            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
+    if model_runner.sliding_window_size is None:
         if flashinfer_use_ragged:
-
-
-
-
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices = torch.cat(
+            [
+                model_runner.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                ]
+                for i in range(batch_size)
+            ],
+            dim=0,
+        ).contiguous()
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+        if forward_mode == ForwardMode.DECODE:
+            # CUDA graph uses different flashinfer_decode_wrapper
+            if flashinfer_decode_wrapper is None:
+                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+            flashinfer_decode_wrapper.end_forward()
+            flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                num_qo_heads,
                num_kv_heads,
                head_dim,
+                1,
+                data_type=model_runner.kv_cache_dtype,
+                q_data_type=model_runner.dtype,
            )
+        else:
+            # extend part
+            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+            if flashinfer_use_ragged:
+                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                )

-
-
-
-
-
-
-
-
-
-
-
-
+            # cached part
+            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+    else:
+        # window attention use paged only
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                if forward_mode == ForwardMode.DECODE:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens,
+                        torch.tensor(model_runner.sliding_window_size)
+                        + seq_lens
+                        - prefix_lens,
+                    )
+            else:
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+            kv_indices = torch.cat(
+                [
+                    model_runner.req_to_token_pool.req_to_token[
+                        req_pool_indices_cpu[i],
+                        kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
+                    ]
+                    for i in range(batch_size)
+                ],
+                dim=0,
+            ).contiguous()
+
+            if forward_mode == ForwardMode.DECODE:
+                # CUDA graph uses different flashinfer_decode_wrapper
+                if flashinfer_decode_wrapper is None:
+                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+                flashinfer_decode_wrapper[wrapper_id].end_forward()
+                flashinfer_decode_wrapper[wrapper_id].begin_forward(
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                    data_type=model_runner.kv_cache_dtype,
+                    q_data_type=model_runner.dtype,
+                )
+            else:
+                # extend part
+                qo_indptr = torch.zeros(
+                    (batch_size + 1,), dtype=torch.int32, device="cuda"
+                )
+                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )