sglang 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/backend/runtime_endpoint.py +14 -4
- sglang/bench_latency.py +6 -3
- sglang/global_config.py +22 -16
- sglang/lang/chat_template.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/srt/layers/radix_attention.py +14 -37
- sglang/srt/layers/token_attention.py +2 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/infer_batch.py +256 -42
- sglang/srt/managers/controller/manager_multi.py +6 -2
- sglang/srt/managers/controller/manager_single.py +125 -50
- sglang/srt/managers/controller/model_runner.py +69 -284
- sglang/srt/managers/controller/radix_cache.py +4 -3
- sglang/srt/managers/controller/schedule_heuristic.py +4 -0
- sglang/srt/managers/controller/tp_worker.py +44 -44
- sglang/srt/memory_pool.py +52 -50
- sglang/srt/models/minicpm.py +1 -8
- sglang/srt/models/qwen2_moe.py +126 -107
- sglang/srt/server.py +11 -15
- sglang/srt/server_args.py +12 -4
- sglang/srt/utils.py +1 -1
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/METADATA +9 -1
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/RECORD +27 -26
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/backend/runtime_endpoint.py
CHANGED
@@ -12,7 +12,6 @@ from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
-
     def __init__(
         self,
         base_url: str,
@@ -38,7 +37,8 @@ class RuntimeEndpoint(BaseBackend):
         self.model_info = res.json()
 
         self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_path"])
+            self.model_info["model_path"]
+        )
 
     def get_model_name(self):
         return self.model_info["model_path"]
@@ -124,7 +124,12 @@
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
-        for item in [
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
                 data[item] = value
@@ -171,7 +176,12 @@
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
-        for item in [
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
                 data[item] = value
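The hunks above reformat a single-line list into Black's multi-line layout; the loop itself copies optional logprob settings into the request payload only when the caller set them. A standalone sketch of that pattern (illustrative only, not the RuntimeEndpoint API):

# Hypothetical helper mirroring the getattr-based forwarding seen in the diff.
OPTIONAL_FIELDS = [
    "return_logprob",
    "logprob_start_len",
    "top_logprobs_num",
    "return_text_in_logprobs",
]

def add_optional_fields(sampling_params, data: dict) -> dict:
    for item in OPTIONAL_FIELDS:
        value = getattr(sampling_params, item, None)
        if value is not None:  # unset fields are omitted from the payload
            data[item] = value
    return data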
sglang/bench_latency.py
CHANGED
@@ -70,6 +70,7 @@ class BenchArgs:
 
 def load_model(server_args, tp_rank):
     suppress_other_loggers()
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     model_config = ModelConfig(path=server_args.model_path)
     model_runner = ModelRunner(
@@ -81,7 +82,7 @@ def load_model(server_args, tp_rank):
         nccl_port=28888,
         server_args=server_args,
     )
-
+    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
     tokenizer = get_tokenizer(
         server_args.tokenizer_path,
         tokenizer_mode=server_args.tokenizer_mode,
@@ -201,7 +202,7 @@ def correctness_test(
 
     # Print
     for i in range(len(reqs)):
-
+        rank_print(tokenizer.decode(output_ids[i]))
 
 
 def latency_test(
@@ -213,7 +214,7 @@
 
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-
+    rank_print(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )
 
@@ -299,6 +300,8 @@ def main(server_args, bench_args):
     for proc in workers:
         proc.join()
 
+    proc.terminate()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
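The bench_latency.py changes route benchmark output through a rank-gated printer so only tensor-parallel rank 0 writes to stdout. The idiom in isolation (illustrative only):

# print on rank 0, no-op everywhere else
def make_rank_print(tp_rank: int):
    return print if tp_rank == 0 else (lambda *args, **kwargs: None)

rank_print = make_rank_print(tp_rank=0)
rank_print("max_total_num_tokens=...")  # emitted only by rank 0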
sglang/global_config.py
CHANGED
@@ -8,36 +8,42 @@ class GlobalConfig:
         # 2: output final text after every run
         self.verbosity = 0
 
+        # Default backend of the language
         self.default_backend = None
 
-        #
+        # Runtime constants: Request dependency time due to network delay
+        self.request_dependency_delay = 0.02
+        self.wait_for_new_request_delay = 0.0006
+
+        # Runtime constants: New generation token ratio estimation
+        self.base_new_token_ratio = 0.4
+        self.base_min_new_token_ratio = 0.2
+        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_recovery = 0.05
+
+        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192
+
+        # Runtime constants: others
+        self.num_continue_decode_steps = 10
+        self.flashinfer_workspace_size = 192 * 1024 * 1024
+
+        # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True
 
-        #
+        # Interpreter optimization configs
         self.eager_fill_image = False
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True
 
+        # Deprecated
         # Choices: ["no_adjust", "adjust_cache"]
         # no_adjust: Do not adjust the position embedding of KV cache.
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"
 
-        # Request dependency time due to network delay
-        self.request_dependency_delay = 0.02
-        self.wait_for_new_request_delay = 0.0006
-
-        # New generation token ratio estimation
-        self.base_new_token_ratio = 0.4
-        self.base_min_new_token_ratio = 0.2
-        self.new_token_ratio_decay = 0.0001
-        self.new_token_ratio_recovery = 0.05
-
-        # The threshold (number of tokens) to trigger layer-wise cuda sync.
-        # This can improve the speed for large batch sizes during prefill.
-        self.layer_sync_threshold = 8192
-
 
 global_config = GlobalConfig()
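These runtime constants are read through the module-level global_config instance, using the same import that appears in the attention and cuda-graph files later in this diff; the attribute reads here are only an illustration:

from sglang.global_config import global_config

workspace_bytes = global_config.flashinfer_workspace_size  # 192 * 1024 * 1024 bytes
sync_threshold = global_config.layer_sync_threshold        # 8192 tokens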
sglang/lang/chat_template.py
CHANGED
@@ -84,7 +84,7 @@ register_chat_template(
             "system": ("SYSTEM:", "\n"),
             "user": ("USER:", "\n"),
             "assistant": ("ASSISTANT:", "\n"),
-        }
+        },
     )
 )
 
@@ -177,7 +177,7 @@ register_chat_template(
             "assistant": ("", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>",)
+        stop_str=("<|im_end|>",),
     )
 )
 
sglang/lang/ir.py
CHANGED
@@ -24,9 +24,9 @@ class SglSamplingParams:
     presence_penalty: float = 0.0
     ignore_eos: bool = False
     return_logprob: Optional[bool] = None
-    logprob_start_len: Optional[int] = None,
-    top_logprobs_num: Optional[int] = None,
-    return_text_in_logprobs: Optional[bool] = None,
+    logprob_start_len: Optional[int] = (None,)
+    top_logprobs_num: Optional[int] = (None,)
+    return_text_in_logprobs: Optional[bool] = (None,)
 
     # for constrained generation, not included in to_xxx_kwargs
     dtype: Optional[str] = None
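Note that the old fields already ended with a trailing comma, so their defaults were one-element tuples; Black only makes that explicit as (None,). A minimal illustration (plain Python 3.8+, not sglang code):

from typing import Optional


class Example:
    a: Optional[int] = None     # default is None
    b: Optional[int] = None,    # trailing comma: the default is the tuple (None,)
    c: Optional[int] = (None,)  # what Black writes for `b`; same value


assert Example.b == (None,)
assert Example.b == Example.c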
sglang/srt/layers/radix_attention.py
CHANGED
@@ -1,6 +1,5 @@
 """Radix attention."""
 
-import numpy as np
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
@@ -8,6 +7,7 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
@@ -29,24 +29,14 @@ class RadixAttention(nn.Module):
         self.scaling = scaling
         self.layer_id = layer_id
 
-        from sglang.srt.managers.controller.model_runner import global_server_args_dict
-
         if not global_server_args_dict.get("disable_flashinfer", False):
-            self.
-            self.extend_forward = self.prefill_forward_flashinfer
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap if logit_cap is not None else 0
 
-
-        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
-        # See the extend_forward_xxx functions.
-        raise NotImplementedError()
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -60,13 +50,13 @@
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.
+            input_metadata.triton_prefix_lens,
             input_metadata.extend_start_loc,
             input_metadata.extend_seq_lens,
-            input_metadata.
-            input_metadata.
+            input_metadata.triton_max_seq_len,
+            input_metadata.triton_max_extend_len,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
         )
@@ -84,10 +74,9 @@
             o.view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.
-            input_metadata.other_kv_index,
+            input_metadata.triton_max_seq_len,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
@@ -95,7 +84,7 @@
 
         return o
 
-    def
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
@@ -105,7 +94,7 @@
             logits_soft_cap=self.logit_cap,
         )
 
-        if input_metadata.
+        if input_metadata.extend_no_prefix:
             o = o1
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
@@ -141,25 +130,13 @@
         k = k.view(-1, self.tp_k_head_num, self.head_dim)
         v = v.view(-1, self.tp_v_head_num, self.head_dim)
 
-        if input_metadata.forward_mode == ForwardMode.
-            return self.prefill_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        key_buffer[input_metadata.out_cache_loc] = cache_k
         value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-
-            key_buffer[input_metadata.out_cache_loc] = cache_k
-            value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
-        else:
-            raise RuntimeError()
+        value_buffer[input_metadata.out_cache_loc] = cache_v
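The constructor above binds backend-specific implementations once, so forward() only dispatches on the forward mode. A simplified sketch of that pattern (not the real RadixAttention class):

class AttentionDispatch:
    def __init__(self, use_flashinfer: bool):
        # Bind the kernel implementations up front instead of branching per call.
        if use_flashinfer:
            self.extend_forward = self._extend_flashinfer
            self.decode_forward = self._decode_flashinfer
        else:
            self.extend_forward = self._extend_triton
            self.decode_forward = self._decode_triton

    def forward(self, q, k, v, mode: str):
        if mode == "extend":
            return self.extend_forward(q, k, v)
        return self.decode_forward(q, k, v)

    def _extend_flashinfer(self, q, k, v): ...
    def _decode_flashinfer(self, q, k, v): ...
    def _extend_triton(self, q, k, v): ...
    def _decode_triton(self, q, k, v): ...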
sglang/srt/layers/token_attention.py
CHANGED
@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@
         + cur_batch_req_idx * stride_req_to_token_b
         + (start_n + offs_n),
         mask=(start_n + offs_n) < cur_batch_seq_len,
-        other=
+        other=0,
     )
 
     qk = tl.load(
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
             o.stride(0),
             o.stride(1),
             req_to_tokens.stride(0),
-            other_kv_index,
         )
         return
 
@@ -295,7 +292,6 @@
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=v_buffer.shape[-1],
         BLOCK_N=BLOCK,
@@ -315,9 +311,8 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
-    sm_scale
+    sm_scale,
     logit_cap=-1,
     att_m=None,
 ):
@@ -325,7 +320,6 @@
         att_m = torch.empty(
             (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
        )
-    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
 
     _token_att_m_fwd(
         q,
@@ -347,5 +341,4 @@
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )
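The kernel change above drops the dummy other_kv_index and uses other=0 as the fill value of a masked load. How other= behaves in Triton, shown on a toy kernel (not sglang code; requires a CUDA device):

import torch
import triton
import triton.language as tl


@triton.jit
def masked_copy(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # Lanes with mask=False never touch memory; they read `other` instead,
    # so 0 is a safe fill value and no valid dummy index is needed.
    x = tl.load(src_ptr + offs, mask=offs < n, other=0)
    tl.store(dst_ptr + offs, x, mask=offs < n)


src = torch.arange(10, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
masked_copy[(1,)](src, dst, src.numel(), BLOCK=16)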
sglang/srt/managers/controller/cuda_graph_runner.py
ADDED
@@ -0,0 +1,196 @@
+"""Run the model with cuda graph."""
+
+import bisect
+
+import torch
+from vllm.distributed.parallel_state import graph_capture
+
+from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    init_flashinfer_args,
+)
+
+
+class CudaGraphRunner:
+    def __init__(self, model_runner, max_batch_size_to_capture):
+        self.model_runner = model_runner
+        self.graphs = {}
+        self.input_buffers = {}
+        self.output_buffers = {}
+        self.flashinfer_handlers = {}
+        self.graph_memory_pool = None
+
+        # Common inputs
+        self.max_bs = max_batch_size_to_capture
+        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.req_pool_indices = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.out_cache_loc = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+        # FlashInfer inputs
+        self.flashinfer_workspace_buffer = (
+            self.model_runner.flashinfer_workspace_buffers[0]
+        )
+        self.flashinfer_kv_indptr = torch.zeros(
+            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_indices = torch.zeros(
+            (self.max_bs * model_runner.model_config.context_len,),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.flashinfer_kv_last_page_len = torch.ones(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+    def can_run(self, batch_size):
+        return batch_size < self.max_bs
+
+    def capture(self, batch_size_list):
+        self.batch_size_list = batch_size_list
+        with graph_capture() as graph_capture_context:
+            self.stream = graph_capture_context.stream
+            for bs in batch_size_list:
+                (
+                    graph,
+                    input_buffers,
+                    output_buffers,
+                    flashinfer_handler,
+                ) = self.capture_one_batch_size(bs)
+                self.graphs[bs] = graph
+                self.input_buffers[bs] = input_buffers
+                self.output_buffers[bs] = output_buffers
+                self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs):
+        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+
+        # Common inputs
+        input_ids = self.input_ids[:bs]
+        req_pool_indices = self.req_pool_indices[:bs]
+        seq_lens = self.seq_lens[:bs]
+        position_ids_offsets = self.position_ids_offsets[:bs]
+        out_cache_loc = self.out_cache_loc[:bs]
+
+        # FlashInfer inputs
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_runner.model_config.num_attention_heads
+            // self.model_runner.tp_size,
+            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffer,
+            "NHD",
+            use_cuda_graph=True,
+            use_tensor_cores=use_tensor_cores,
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
+            paged_kv_indices_buffer=self.flashinfer_kv_indices,
+            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+        )
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices,
+            seq_lens,
+            None,
+            flashinfer_decode_wrapper,
+        )
+
+        # Run and capture
+        def run_once():
+            input_metadata = InputMetadata.create(
+                self.model_runner,
+                forward_mode=ForwardMode.DECODE,
+                req_pool_indices=req_pool_indices,
+                seq_lens=seq_lens,
+                prefix_lens=None,
+                position_ids_offsets=position_ids_offsets,
+                out_cache_loc=out_cache_loc,
+                return_logprob=False,
+                top_logprobs_nums=0,
+                skip_flashinfer_init=True,
+            )
+            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
+            return self.model_runner.model.forward(
+                input_ids, input_metadata.positions, input_metadata
+            )
+
+        for _ in range(2):
+            run_once()
+
+        torch.cuda.synchronize()
+        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+            out = run_once()
+        torch.cuda.synchronize()
+        self.graph_memory_pool = graph.pool()
+        return graph, None, out, flashinfer_decode_wrapper
+
+    def replay(self, batch: Batch):
+        assert batch.out_cache_loc is not None
+        assert not batch.return_logprob
+        raw_bs = len(batch.reqs)
+
+        # Pad
+        index = bisect.bisect_left(self.batch_size_list, raw_bs)
+        bs = self.batch_size_list[index]
+        if bs != raw_bs:
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
+            self.out_cache_loc.zero_()
+
+        # Common inputs
+        self.input_ids[:raw_bs] = batch.input_ids
+        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
+        self.seq_lens[:raw_bs] = batch.seq_lens
+        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
+        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+
+        # FlashInfer inputs
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            None,
+            self.flashinfer_handlers[bs],
+        )
+
+        # Replay
+        self.graphs[bs].replay()
+        output = self.output_buffers[bs]
+
+        # Unpad
+        if bs == raw_bs:
+            return output
+        else:
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
+                next_token_logprobs=output.next_token_logprobs[:raw_bs]
+                if output.next_token_logprobs is not None
+                else None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
+                if output.decode_top_logprobs is not None
+                else None,
+            )
+            return output