sglang 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/runtime_endpoint.py +14 -4
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -20
- sglang/bench_serving.py +758 -0
- sglang/check_env.py +171 -0
- sglang/global_config.py +3 -1
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/chat_template.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +31 -5
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +44 -18
- sglang/srt/managers/controller/infer_batch.py +76 -72
- sglang/srt/managers/controller/manager_multi.py +109 -98
- sglang/srt/managers/controller/manager_single.py +105 -50
- sglang/srt/managers/controller/model_runner.py +42 -18
- sglang/srt/managers/controller/radix_cache.py +4 -3
- sglang/srt/managers/controller/schedule_heuristic.py +4 -0
- sglang/srt/managers/controller/tp_worker.py +143 -156
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +46 -58
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +65 -16
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +2 -8
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +130 -108
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +114 -90
- sglang/srt/server_args.py +27 -17
- sglang/srt/utils.py +17 -118
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -159
- sglang-0.1.22.dist-info/RECORD +103 -0
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
- sglang-0.1.20.dist-info/RECORD +0 -82
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0

--- a/sglang/srt/layers/radix_attention.py
+++ b/sglang/srt/layers/radix_attention.py
@@ -7,8 +7,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.server import global_server_args_dict


 class RadixAttention(nn.Module):
@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
             return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-
-
-
-
+        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
+        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+
+
+try:
+
+    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
+
+    @_store_kv_cache.register_fake
+    def _(k, v, kv_cache, cache_loc):
+        pass
+
+except:
+
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
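
In the new `store_kv_cache` above, the in-place KV write is registered as a PyTorch custom op (with a `register_fake` meta implementation) so that tracing and CUDA-graph capture treat it as one opaque mutating operation, falling back to a plain function when `torch.library.custom_op` is unavailable. Below is a minimal, self-contained sketch of the same pattern; the `demo::` namespace, the CPU tensors, and the shapes are illustrative only, not sglang's actual layout.

```python
import torch

try:
    # Newer PyTorch (>= 2.4): register the in-place KV write as a custom op so
    # graph tracing sees a single opaque op that mutates `kv_cache`.
    @torch.library.custom_op("demo::store_kv_cache", mutates_args={"kv_cache"})
    def store_kv_cache(
        k: torch.Tensor, v: torch.Tensor, kv_cache: torch.Tensor, cache_loc: torch.Tensor
    ) -> None:
        kv_cache[cache_loc, 0] = k
        kv_cache[cache_loc, 1] = v

    @store_kv_cache.register_fake
    def _(k, v, kv_cache, cache_loc):
        # The fake (meta) implementation returns nothing; it only tells the
        # tracer that the op produces no outputs.
        pass

except AttributeError:
    # Older PyTorch: plain function with the same behavior.
    def store_kv_cache(k, v, kv_cache, cache_loc):
        kv_cache[cache_loc, 0] = k
        kv_cache[cache_loc, 1] = v

# Tiny usage example on CPU tensors (the real code writes into the GPU KV pool).
kv_cache = torch.zeros(8, 2, 4)  # [num_slots, {k, v}, head_dim]
store_kv_cache(torch.ones(2, 4), torch.full((2, 4), 2.0), kv_cache, torch.tensor([3, 5]))
print(kv_cache[3, 0], kv_cache[5, 1])
```
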
--- a/sglang/srt/layers/token_attention.py
+++ b/sglang/srt/layers/token_attention.py
@@ -5,8 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.
-from sglang.srt.utils import wrap_kernel_launcher
+from sglang.srt.server import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)


-cached_kernel_stage1 = None
-cached_kernel_stage2 = None
-
-
 def _token_att_m_fwd(
     q,
     k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
     else:
         num_warps = 2

-    global cached_kernel_stage1
-    if cached_kernel_stage1:
-        cached_kernel_stage1(
-            grid,
-            num_warps,
-            q,
-            k_buffer,
-            sm_scale,
-            Req_to_tokens,
-            B_req_idx,
-            B_Start_Loc,
-            B_Seqlen,
-            att_out,
-            Req_to_tokens.stride(0),
-            q.stride(0),
-            q.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            att_out.stride(0),
-        )
-        return
-
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)


 def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(

     num_warps = 1

-    global cached_kernel_stage2
-    if cached_kernel_stage2:
-        cached_kernel_stage2(
-            grid,
-            num_warps,
-            logics,
-            v_buffer,
-            o,
-            req_to_tokens,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            logics.stride(0),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            o.stride(0),
-            o.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
         num_warps=num_warps,
         num_stages=3,
     )
-    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)


 def token_attention_fwd(

--- a/sglang/srt/managers/controller/cuda_graph_runner.py
+++ b/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -3,12 +3,16 @@
 import bisect

 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture

-from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
-    Batch,
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    init_flashinfer_args,
 )


@@ -24,18 +28,28 @@ class CudaGraphRunner:
         # Common inputs
         self.max_bs = max_batch_size_to_capture
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.req_pool_indices = torch.zeros(
+        self.req_pool_indices = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
         self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.zeros(
-
+        self.position_ids_offsets = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.out_cache_loc = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )

         # FlashInfer inputs
-        self.flashinfer_workspace_buffer =
+        self.flashinfer_workspace_buffer = (
+            self.model_runner.flashinfer_workspace_buffers[0]
+        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
         self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,),
+            (self.max_bs * model_runner.model_config.context_len,),
+            dtype=torch.int32,
+            device="cuda",
         )
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
@@ -49,16 +63,18 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-
+                (
+                    graph,
+                    input_buffers,
+                    output_buffers,
+                    flashinfer_handler,
+                ) = self.capture_one_batch_size(bs)
                 self.graphs[bs] = graph
                 self.input_buffers[bs] = input_buffers
                 self.output_buffers[bs] = output_buffers
                 self.flashinfer_handlers[bs] = flashinfer_handler

     def capture_one_batch_size(self, bs):
-        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         graph = torch.cuda.CUDAGraph()
         stream = self.stream

@@ -71,17 +87,19 @@ class CudaGraphRunner:

         # FlashInfer inputs
         if not _grouped_size_compiled_for_decode_kernels(
-            self.model_runner.model_config.num_attention_heads
+            self.model_runner.model_config.num_attention_heads
+            // self.model_runner.tp_size,
             self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
         ):
             use_tensor_cores = True
         else:
             use_tensor_cores = False
         flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffer,
+            self.flashinfer_workspace_buffer,
+            "NHD",
             use_cuda_graph=True,
             use_tensor_cores=use_tensor_cores,
-            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
             paged_kv_indices_buffer=self.flashinfer_kv_indices,
             paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
         )
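
The fixed-size `max_bs` buffers allocated in `__init__` above exist because a captured CUDA graph replays fixed memory addresses: live requests are copied into those static tensors and the graph is replayed. A minimal sketch of that capture/copy/replay pattern using the stock PyTorch CUDA-graph API; `build_graph` and the toy `Linear` model are illustrative and not part of sglang.

```python
import torch

@torch.no_grad()
def build_graph(model, max_bs, hidden):
    # Static input buffer that the captured graph will read from on every replay.
    static_in = torch.zeros(max_bs, hidden, device="cuda")

    # Warm up on a side stream before capture (recommended for CUDA graphs).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = model(static_in)
    return graph, static_in, static_out

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    graph, static_in, static_out = build_graph(model, max_bs=8, hidden=16)

    x = torch.randn(3, 16, device="cuda")  # raw_bs = 3 < max_bs = 8
    static_in[:3].copy_(x)                 # pad: untouched rows keep their old values
    graph.replay()
    print(static_out[:3].shape)            # slice outputs back to the real batch size
```
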
@@ -132,8 +150,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.
-            self.position_ids_offsets.
+            self.seq_lens.fill_(1)
+            self.position_ids_offsets.zero_()
             self.out_cache_loc.zero_()

         # Common inputs
@@ -163,10 +181,18 @@ class CudaGraphRunner:
         else:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=
+                next_token_logprobs=(
+                    output.next_token_logprobs[:raw_bs]
+                    if output.next_token_logprobs is not None
+                    else None
+                ),
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=
+                decode_top_logprobs=(
+                    output.decode_top_logprobs[:raw_bs]
+                    if output.decode_top_logprobs is not None
+                    else None
+                ),
             )
         return output
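
The `@@ -132` hunk above pads a live batch up to the nearest captured graph size with `bisect_left`, and the `@@ -163` hunk slices logits and logprobs back down to `raw_bs`. A small sketch of that size-selection rule; `pick_padded_batch_size` is a hypothetical helper, not sglang code.

```python
import bisect

def pick_padded_batch_size(batch_size_list, raw_bs):
    # Choose the smallest captured batch size that can hold the live batch.
    batch_size_list = sorted(batch_size_list)
    index = bisect.bisect_left(batch_size_list, raw_bs)
    if index == len(batch_size_list):
        raise ValueError(f"batch size {raw_bs} exceeds the largest captured graph")
    return batch_size_list[index]

captured = [1, 2, 4, 8, 16]
for raw_bs in (1, 3, 9):
    print(raw_bs, "->", pick_padded_batch_size(captured, raw_bs))  # 1->1, 3->4, 9->16
```
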
--- a/sglang/srt/managers/controller/infer_batch.py
+++ b/sglang/srt/managers/controller/infer_batch.py
@@ -7,6 +7,7 @@ from typing import List, Union

 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs

 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
@@ -15,9 +16,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

-# Store some global server args
-global_server_args_dict = {}
-

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -84,6 +82,15 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids

         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- |   surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
+        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -134,7 +141,7 @@ class Req:
         return self.finished_reason is not None

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None

         if first_iter:
@@ -144,13 +151,11 @@ class Req:
             )

         all_ids = self.origin_input_ids_unpadded + self.output_ids
-
-        read_ids = all_ids[self.surr_offset :]
-
-        return surr_ids, read_ids, len(all_ids)
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

-    def
-
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]

         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -164,19 +169,10 @@ class Req:
         )

         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]

         return False, ""

-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
-
     def check_finished(self):
         if self.finished():
             return
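
The `surr_offset` / `read_offset` diagram and `get_next_inc_detokenization` above implement incremental detokenization: re-decode a short window starting at the surrounding offset and emit only the new suffix once it no longer ends in the replacement character. A toy, dependency-free model of that bookkeeping; `ByteTokenizer` and `stream_decode` are stand-ins (byte-valued ids instead of a real tokenizer), not sglang code.

```python
class ByteTokenizer:
    # Token ids are raw UTF-8 bytes, so a multi-byte character decodes to the
    # replacement character until all of its bytes have arrived.
    def decode(self, ids):
        return bytes(ids).decode("utf-8", errors="replace")

def stream_decode(tokenizer, all_ids):
    decoded_text, surr_offset, read_offset = "", 0, 0
    for n in range(1, len(all_ids) + 1):
        surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
        new_text = tokenizer.decode(all_ids[surr_offset:n])
        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            decoded_text += new_text[len(surr_text):]   # emit only the new suffix
            surr_offset, read_offset = read_offset, n   # advance both offsets
            yield decoded_text

for text in stream_decode(ByteTokenizer(), list("héllo wörld".encode("utf-8"))):
    print(text)
```
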
@@ -275,6 +271,7 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
+    extend_num_tokens: int = None

     # For processing logprobs
     return_logprob: bool = False
@@ -285,10 +282,6 @@ class Batch:
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

-    # Other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
@@ -330,6 +323,13 @@ class Batch:
         seq_lens = []

         req_pool_indices = self.req_to_token_pool.alloc(bs)
+
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
+
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
         for i in range(bs):
             flatten_input_ids.extend(input_ids[i])
@@ -352,7 +352,7 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
             out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

             if out_cache_loc is None:
@@ -401,10 +401,10 @@ class Batch:
         ).view(-1, 1)
         self.top_ps = torch.tensor(
             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
+        )
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
+        )
         self.frequency_penalties = torch.tensor(
             [r.sampling_params.frequency_penalty for r in reqs],
             dtype=torch.float,
@@ -422,7 +422,7 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True

-        self.tree_cache.evict(bs, self.token_to_kv_pool.
+        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

         if self.token_to_kv_pool.available_size() >= bs:
             return True
@@ -453,7 +453,7 @@ class Batch:
             token_indices = self.req_to_token_pool.req_to_token[
                 req_pool_indices_cpu[idx]
             ][last_uncached_pos : seq_lens_cpu[idx]]
-            self.token_to_kv_pool.
+            self.token_to_kv_pool.free(token_indices)

             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
@@ -502,7 +502,7 @@ class Batch:
             cur_output_ids = req.output_ids

             req.output_ids.extend(suffix_ids)
-            decode_res, new_text = req.
+            decode_res, new_text = req.get_next_inc_detokenization()
             if not decode_res:
                 req.output_ids = cur_output_ids
                 continue
@@ -521,6 +521,9 @@ class Batch:
                 req.output_ids = cur_output_ids
                 continue

+            # The decode status has diverged from detokenizer_manager
+            req.vid += 1
+
             # insert the old request into tree_cache
             if req_pool_indices_cpu is None:
                 req_pool_indices_cpu = self.req_pool_indices.tolist()
@@ -596,8 +599,7 @@ class Batch:
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
-            # logit_bias can be None
-            if self_val is not None:
+            if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])

     def merge(self, other: "Batch"):
|
|
663
665
|
|
664
666
|
# TODO(lmzheng): apply penalty
|
665
667
|
probs = torch.softmax(logits, dim=-1)
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
|
672
|
-
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
|
673
|
-
-1
|
668
|
+
|
669
|
+
max_top_k_round, batch_size = 32, probs.shape[0]
|
670
|
+
uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
|
671
|
+
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
|
672
|
+
probs, uniform_samples, self.top_ks, self.top_ps
|
674
673
|
)
|
675
|
-
|
676
|
-
|
677
|
-
|
674
|
+
|
675
|
+
# FIXME: this is a temporary fix for the illegal token ids
|
676
|
+
illegal_mask = torch.logical_or(
|
677
|
+
batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
|
678
|
+
)
|
679
|
+
if torch.any(illegal_mask):
|
680
|
+
warnings.warn("Illegal sampled token ids")
|
681
|
+
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
682
|
+
batch_next_token_ids = torch.argmax(probs, dim=-1)
|
678
683
|
|
679
684
|
if has_regex:
|
680
685
|
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
|
@@ -684,18 +689,7 @@ class Batch:
|
|
684
689
|
req.regex_fsm_state, batch_next_token_ids_cpu[i]
|
685
690
|
)
|
686
691
|
|
687
|
-
return batch_next_token_ids
|
688
|
-
|
689
|
-
|
690
|
-
def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
|
691
|
-
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
692
|
-
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
693
|
-
probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
|
694
|
-
probs_sort[
|
695
|
-
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
|
696
|
-
] = 0.0
|
697
|
-
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
698
|
-
return probs_sort, probs_idx
|
692
|
+
return batch_next_token_ids
|
699
693
|
|
700
694
|
|
701
695
|
@dataclass
|
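
The sampling hunk above swaps the hand-rolled `_top_p_top_k` masking and gather for flashinfer's fused `top_k_top_p_sampling_from_probs`, keeping an argmax fallback for illegal token ids. For reference, a pure-PyTorch sketch of the filtering the fused kernel performs; this is a simplified stand-in, not the removed sglang code verbatim.

```python
import torch

def top_k_top_p_sample(probs, top_ks, top_ps):
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens outside the nucleus (top-p) ...
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # ... and beyond each request's top-k cutoff.
    ranks = torch.arange(probs.shape[-1], device=probs.device).view(1, -1)
    probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
    probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)

probs = torch.softmax(torch.randn(4, 50), dim=-1)
ids = top_k_top_p_sample(
    probs, top_ks=torch.tensor([10, 5, 50, 1]), top_ps=torch.tensor([0.9, 0.8, 1.0, 1.0])
)
print(ids)
```
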
@@ -749,8 +743,14 @@ class InputMetadata:
         skip_flashinfer_init=False,
     ):
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            init_flashinfer_args(
-
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+            )

         batch_size = len(req_pool_indices)

@@ -807,16 +807,25 @@ class InputMetadata:
         )

         if model_runner.server_args.disable_flashinfer:
-            (
-
-
-
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)

         return ret


-def init_flashinfer_args(
-
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+):
+    """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
@@ -827,9 +836,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
     else:
         paged_kernel_lens = prefix_lens

-    kv_indptr = torch.zeros(
-        (batch_size + 1,), dtype=torch.int32, device="cuda"
-    )
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
     req_pool_indices_cpu = req_pool_indices.cpu().numpy()
     paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
@@ -842,9 +849,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         ],
         dim=0,
     ).contiguous()
-    kv_last_page_len = torch.ones(
-        (batch_size,), dtype=torch.int32, device="cuda"
-    )
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

     if forward_mode == ForwardMode.DECODE:
         flashinfer_decode_wrapper.end_forward()
@@ -859,9 +864,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         )
     else:
         # extend part
-        qo_indptr = torch.zeros(
-            (batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

         model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
@@ -888,6 +891,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,


 def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
     batch_size = len(seq_lens)
     max_seq_len = int(torch.max(seq_lens))
     start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")