sglang 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +26 -0
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +40 -18
- sglang/global_config.py +21 -16
- sglang/lang/chat_template.py +41 -6
- sglang/lang/interpreter.py +5 -1
- sglang/lang/ir.py +61 -25
- sglang/srt/constrained/__init__.py +3 -2
- sglang/srt/hf_transformers_utils.py +7 -3
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +181 -167
- sglang/srt/layers/logits_processor.py +55 -19
- sglang/srt/layers/radix_attention.py +33 -59
- sglang/srt/layers/token_attention.py +4 -8
- sglang/srt/managers/controller/cuda_graph_runner.py +172 -0
- sglang/srt/managers/controller/infer_batch.py +244 -36
- sglang/srt/managers/controller/manager_single.py +1 -1
- sglang/srt/managers/controller/model_runner.py +69 -284
- sglang/srt/managers/controller/tp_worker.py +39 -20
- sglang/srt/managers/detokenizer_manager.py +4 -2
- sglang/srt/managers/io_struct.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/memory_pool.py +33 -6
- sglang/srt/model_config.py +6 -0
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/llama2.py +3 -3
- sglang/srt/models/llama_classification.py +10 -7
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/openai_api_adapter.py +2 -2
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +18 -8
- sglang/srt/server_args.py +24 -20
- sglang/srt/utils.py +68 -35
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/METADATA +19 -13
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/RECORD +40 -36
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/WHEEL +1 -1
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/LICENSE +0 -0
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py

@@ -1,60 +1,42 @@
 """Radix attention."""
 
-import numpy as np
 import torch
+from flashinfer.cascade import merge_state
 from torch import nn
 
 from sglang.global_config import global_config
-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
 class RadixAttention(nn.Module):
     def __init__(
-        self,
-
+        self,
+        num_heads: int,
+        head_dim: int,
+        scaling: float,
+        num_kv_heads: int,
+        layer_id: int,
+        logit_cap: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.scaling = scaling
         self.layer_id = layer_id
 
-        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
-
-        from sglang.srt.managers.controller.model_runner import global_server_args_dict
-
         if not global_server_args_dict.get("disable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap > 0 else 0
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap
-
-    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
-
-        context_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
-            k,
-            v,
-            o.view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.start_loc,
-            input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-            self.logit_cap,
-        )
-        self.store_kv_cache(k, v, input_metadata)
 
-
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -68,14 +50,15 @@ class RadixAttention(nn.Module):
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.prefix_lens,
+            input_metadata.triton_prefix_lens,
             input_metadata.extend_start_loc,
             input_metadata.extend_seq_lens,
-            input_metadata.max_seq_len,
-            input_metadata.max_extend_len,
-            self.logit_cap,
+            input_metadata.triton_max_seq_len,
+            input_metadata.triton_max_extend_len,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )
 
         return o
@@ -91,39 +74,41 @@ class RadixAttention(nn.Module):
             o.view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-            input_metadata.other_kv_index,
+            input_metadata.triton_max_seq_len,
             input_metadata.total_num_tokens,
-            self.logit_cap,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )
 
         return o
 
-    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
-
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
             v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            causal=True,
+            sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
 
-        if input_metadata.no_prefix:
+        if input_metadata.extend_no_prefix:
             o = o1
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.kv_data[self.layer_id],
                 causal=False,
+                sm_scale=self.scaling,
                 logits_soft_cap=self.logit_cap,
             )
 
-            from flashinfer.cascade import merge_state
             o, _ = merge_state(o1, s1, o2, s2)
 
+        self.store_kv_cache(k, v, input_metadata)
+
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()
 
@@ -135,6 +120,7 @@ class RadixAttention(nn.Module):
         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
 
@@ -144,25 +130,13 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.head_dim)
         v = v.view(-1, self.tp_v_head_num, self.head_dim)
 
-        if input_metadata.forward_mode == ForwardMode.PREFILL:
-            return self.prefill_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
            return self.extend_forward(q, k, v, input_metadata)
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        key_buffer[input_metadata.out_cache_loc] = cache_k
         value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        if input_metadata.out_cache_loc is not None:
-            key_buffer[input_metadata.out_cache_loc] = cache_k
-            value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
-        else:
-            raise RuntimeError()
+        value_buffer[input_metadata.out_cache_loc] = cache_v
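
For orientation, the recurring change in this file is that RadixAttention now keeps the softmax scale as self.scaling and a normalized logit cap as attributes, and passes both explicitly to every backend call (sm_scale=..., logits_soft_cap=...) instead of letting each kernel recompute 1/sqrt(head_dim); the separate PREFILL path is also folded into EXTEND. The snippet below is a minimal illustrative sketch of that pattern in plain PyTorch; the names (naive_attention, TinyAttention) are invented for the example and are not code from the package.

    import math

    import torch


    def naive_attention(q, k, v, sm_scale, logit_cap=0.0):
        # q, k, v: (num_tokens, num_heads, head_dim); plain reference math only.
        scores = torch.einsum("qhd,khd->hqk", q, k) * sm_scale
        if logit_cap > 0:
            # Soft-cap the logits, the same idea as flashinfer's logits_soft_cap.
            scores = logit_cap * torch.tanh(scores / logit_cap)
        probs = torch.softmax(scores, dim=-1)
        return torch.einsum("hqk,khd->qhd", probs, v)


    class TinyAttention:
        """The layer owns sm_scale/logit_cap and hands them to the chosen backend."""

        def __init__(self, head_dim, logit_cap=-1):
            self.scaling = 1.0 / math.sqrt(head_dim)
            # Same normalization as the diff: None or a non-positive cap disables it.
            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
            # Stand-ins for choosing the triton or flashinfer kernels once, up front.
            self.extend_forward = self._reference_forward
            self.decode_forward = self._reference_forward

        def _reference_forward(self, q, k, v):
            return naive_attention(q, k, v, sm_scale=self.scaling, logit_cap=self.logit_cap)

        def forward(self, q, k, v, mode):
            if mode == "extend":
                return self.extend_forward(q, k, v)
            elif mode == "decode":
                return self.decode_forward(q, k, v)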
sglang/srt/layers/token_attention.py

@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
         + cur_batch_req_idx * stride_req_to_token_b
         + (start_n + offs_n),
         mask=(start_n + offs_n) < cur_batch_seq_len,
-        other=other_kv_index,
+        other=0,
     )
 
     qk = tl.load(
@@ -176,6 +175,7 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    sm_scale,
     logit_cap,
 ):
     BLOCK = 32
@@ -183,7 +183,6 @@ def _token_att_m_fwd(
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
     assert Lk in {16, 32, 64, 128, 256}
-    sm_scale = 1.0 / (Lk**0.5)
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
     )
     return
 
@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=v_buffer.shape[-1],
         BLOCK_N=BLOCK,
@@ -315,8 +311,8 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
+    sm_scale,
     logit_cap=-1,
     att_m=None,
 ):
@@ -334,6 +330,7 @@ def token_attention_fwd(
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        sm_scale,
         logit_cap,
     )
     _token_softmax_reducev_fwd(
@@ -344,5 +341,4 @@
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )
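
The token_attention.py change is the kernel-side half of the same cleanup: _token_att_m_fwd no longer derives sm_scale = 1.0 / (Lk**0.5) internally but takes it from the caller, and the other_kv_index workaround is dropped in favor of other=0 on masked loads. The sketch below shows, with ordinary tensor ops rather than the real Triton kernel, what a caller-supplied scale buys; the function name is invented for the example.

    import torch


    def token_attention_scores(q, k, sm_scale, logit_cap=0.0):
        # Illustrative only: the real kernel is Triton; this just shows the effect
        # of sm_scale being a parameter rather than recomputed from head_dim.
        scores = (q @ k.transpose(-1, -2)) * sm_scale
        if logit_cap > 0:
            scores = logit_cap * torch.tanh(scores / logit_cap)
        return scores


    head_dim = 128
    q = torch.randn(1, 4, head_dim)
    k = torch.randn(1, 64, head_dim)

    # Default scale, equivalent to what the kernel used to compute internally.
    default = token_attention_scores(q, k, sm_scale=head_dim**-0.5)

    # A model-specific scale can now be threaded through without touching the kernel.
    custom = token_attention_scores(q, k, sm_scale=0.0625)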
sglang/srt/managers/controller/cuda_graph_runner.py

@@ -0,0 +1,172 @@
+"""Run the model with cuda graph."""
+
+import bisect
+
+import torch
+from vllm.distributed.parallel_state import graph_capture
+
+from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.infer_batch import (
+    Batch, ForwardMode, InputMetadata, init_flashinfer_args
+)
+
+
+class CudaGraphRunner:
+    def __init__(self, model_runner, max_batch_size_to_capture):
+        self.model_runner = model_runner
+        self.graphs = {}
+        self.input_buffers = {}
+        self.output_buffers = {}
+        self.flashinfer_handlers = {}
+        self.graph_memory_pool = None
+
+        # Common inputs
+        self.max_bs = max_batch_size_to_capture
+        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+
+        # FlashInfer inputs
+        self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
+        self.flashinfer_kv_indptr = torch.zeros(
+            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_indices = torch.zeros(
+            (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_last_page_len = torch.ones(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+    def can_run(self, batch_size):
+        return batch_size < self.max_bs
+
+    def capture(self, batch_size_list):
+        self.batch_size_list = batch_size_list
+        with graph_capture() as graph_capture_context:
+            self.stream = graph_capture_context.stream
+            for bs in batch_size_list:
+                graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
+                self.graphs[bs] = graph
+                self.input_buffers[bs] = input_buffers
+                self.output_buffers[bs] = output_buffers
+                self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs):
+        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+
+        # Common inputs
+        input_ids = self.input_ids[:bs]
+        req_pool_indices = self.req_pool_indices[:bs]
+        seq_lens = self.seq_lens[:bs]
+        position_ids_offsets = self.position_ids_offsets[:bs]
+        out_cache_loc = self.out_cache_loc[:bs]
+
+        # FlashInfer inputs
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
+            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffer, "NHD",
+            use_cuda_graph=True,
+            use_tensor_cores=use_tensor_cores,
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
+            paged_kv_indices_buffer=self.flashinfer_kv_indices,
+            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+        )
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices,
+            seq_lens,
+            None,
+            flashinfer_decode_wrapper,
+        )
+
+        # Run and capture
+        def run_once():
+            input_metadata = InputMetadata.create(
+                self.model_runner,
+                forward_mode=ForwardMode.DECODE,
+                req_pool_indices=req_pool_indices,
+                seq_lens=seq_lens,
+                prefix_lens=None,
+                position_ids_offsets=position_ids_offsets,
+                out_cache_loc=out_cache_loc,
+                return_logprob=False,
+                top_logprobs_nums=0,
+                skip_flashinfer_init=True,
+            )
+            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
+            return self.model_runner.model.forward(
+                input_ids, input_metadata.positions, input_metadata
+            )
+
+        for _ in range(2):
+            run_once()
+
+        torch.cuda.synchronize()
+        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+            out = run_once()
+        torch.cuda.synchronize()
+        self.graph_memory_pool = graph.pool()
+        return graph, None, out, flashinfer_decode_wrapper
+
+    def replay(self, batch: Batch):
+        assert batch.out_cache_loc is not None
+        assert not batch.return_logprob
+        raw_bs = len(batch.reqs)
+
+        # Pad
+        index = bisect.bisect_left(self.batch_size_list, raw_bs)
+        bs = self.batch_size_list[index]
+        if bs != raw_bs:
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
+            self.out_cache_loc.zero_()
+
+        # Common inputs
+        self.input_ids[:raw_bs] = batch.input_ids
+        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
+        self.seq_lens[:raw_bs] = batch.seq_lens
+        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
+        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+
+        # FlashInfer inputs
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            None,
+            self.flashinfer_handlers[bs],
+        )
+
+        # Replay
+        self.graphs[bs].replay()
+        output = self.output_buffers[bs]
+
+        # Unpad
+        if bs == raw_bs:
+            return output
+        else:
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
+                next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
+            )
+            return output
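
cuda_graph_runner.py is a new file: it captures one CUDA graph per batch size in batch_size_list, and replay() rounds a live batch up to the nearest captured size with bisect, copies the real requests into the front of the persistent input buffers, and slices the padded outputs back down to the real batch. Stripped of all CUDA and sglang types, the padding bookkeeping works roughly like this sketch (names invented for the example, not package code):

    import bisect

    captured_sizes = [1, 2, 4, 8, 16]  # must be sorted for bisect to work


    def pick_graph_size(raw_bs):
        # Smallest captured size that can hold the batch, like bisect_left in replay().
        index = bisect.bisect_left(captured_sizes, raw_bs)
        if index == len(captured_sizes):
            raise ValueError("batch too large for any captured graph")
        return captured_sizes[index]


    def replay(raw_inputs):
        raw_bs = len(raw_inputs)
        bs = pick_graph_size(raw_bs)
        # Copy real inputs into the front of a fixed-size buffer; the tail keeps
        # harmless defaults, mirroring seq_lens.zero_() / position_ids_offsets.fill_(1).
        padded = raw_inputs + [0] * (bs - raw_bs)
        padded_output = [x + 1 for x in padded]   # stand-in for graphs[bs].replay()
        return padded_output[:raw_bs]             # unpad: keep only the real requests


    assert pick_graph_size(3) == 4
    assert replay([10, 20, 30]) == [11, 21, 31]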