sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -19
- sglang/bench_serving.py +976 -0
- sglang/check_env.py +171 -0
- sglang/global_config.py +3 -2
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/hf_transformers_utils.py +13 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/logits_processor.py +4 -5
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +39 -24
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
- sglang/srt/managers/controller/infer_batch.py +90 -63
- sglang/srt/managers/controller/manager_multi.py +107 -100
- sglang/srt/managers/controller/manager_single.py +76 -96
- sglang/srt/managers/controller/model_runner.py +41 -26
- sglang/srt/managers/controller/schedule_heuristic.py +8 -3
- sglang/srt/managers/controller/tp_worker.py +136 -149
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +32 -11
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +430 -0
- sglang/srt/models/gpt_bigcode.py +282 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +81 -23
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +132 -84
- sglang/srt/server_args.py +35 -21
- sglang/srt/utils.py +65 -117
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
- sglang-0.1.24.dist-info/RECORD +105 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
- sglang-0.1.21.dist-info/RECORD +0 -82
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py

@@ -7,8 +7,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.server import global_server_args_dict


 class RadixAttention(nn.Module):
@@ -85,32 +85,47 @@ class RadixAttention(nn.Module):
         return o

     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-
-
-            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-            causal=True,
-            sm_scale=self.scaling,
-            logits_soft_cap=self.logit_cap,
-        )
+        if not input_metadata.use_ragged:
+            self.store_kv_cache(k, v, input_metadata)

-
-            o = o1
-        else:
-            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.
-                causal=
+                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                causal=True,
                 sm_scale=self.scaling,
                 logits_soft_cap=self.logit_cap,
             )
+        else:
+            o1, s1 = (
+                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                    causal=True,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
+                )
+            )

-
+            if input_metadata.extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = (
+                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                        input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                        causal=False,
+                        sm_scale=self.scaling,
+                        logits_soft_cap=self.logit_cap,
+                    )
+                )

-
+                o, _ = merge_state(o1, s1, o2, s2)
+
+            self.store_kv_cache(k, v, input_metadata)

-
-
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()

         return o.view(-1, self.tp_q_head_num * self.head_dim)

@@ -119,7 +134,7 @@ class RadixAttention(nn.Module):

         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.
+            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
@@ -136,7 +151,7 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-
-
-
-
+        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
+        k_cache[input_metadata.out_cache_loc] = cache_k
+        v_cache[input_metadata.out_cache_loc] = cache_v
sglang/srt/layers/token_attention.py

@@ -5,8 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.
-from sglang.srt.utils import wrap_kernel_launcher
+from sglang.srt.server import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)


-cached_kernel_stage1 = None
-cached_kernel_stage2 = None
-
-
 def _token_att_m_fwd(
     q,
     k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
     else:
         num_warps = 2

-    global cached_kernel_stage1
-    if cached_kernel_stage1:
-        cached_kernel_stage1(
-            grid,
-            num_warps,
-            q,
-            k_buffer,
-            sm_scale,
-            Req_to_tokens,
-            B_req_idx,
-            B_Start_Loc,
-            B_Seqlen,
-            att_out,
-            Req_to_tokens.stride(0),
-            q.stride(0),
-            q.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            att_out.stride(0),
-        )
-        return
-
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)


 def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(

     num_warps = 1

-    global cached_kernel_stage2
-    if cached_kernel_stage2:
-        cached_kernel_stage2(
-            grid,
-            num_warps,
-            logics,
-            v_buffer,
-            o,
-            req_to_tokens,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            logics.stride(0),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            o.stride(0),
-            o.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
         num_warps=num_warps,
         num_stages=3,
     )
-    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)


 def token_attention_fwd(
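The deletions above drop the hand-rolled `cached_kernel_stage1` / `cached_kernel_stage2` launchers built with `wrap_kernel_launcher`; the Triton kernels are now launched directly on every call, presumably relying on Triton's own JIT cache to avoid recompilation. A minimal standalone sketch of that direct-launch pattern (a toy kernel for illustration, not sglang's attention kernels):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)

# The first launch triggers compilation; later launches with the same
# specialization reuse Triton's cached binary, so no manual launcher cache is kept.
_add_kernel[grid](x, y, out, x.numel(), BLOCK=1024)
_add_kernel[grid](x, y, out, x.numel(), BLOCK=1024)
```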
sglang/srt/managers/controller/cuda_graph_runner.py

@@ -1,11 +1,14 @@
 """Run the model with cuda graph."""

 import bisect
+from contextlib import contextmanager

 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
+from vllm.model_executor.custom_op import CustomOp

-from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
@@ -13,10 +16,44 @@ from sglang.srt.managers.controller.infer_batch import (
     InputMetadata,
     init_flashinfer_args,
 )
+from sglang.srt.utils import monkey_patch_vllm_all_gather
+
+
+def _to_torch(model: torch.nn.Module, reverse: bool = False):
+    for sub in model._modules.values():
+        if isinstance(sub, CustomOp):
+            if reverse:
+                sub._forward_method = sub.forward_cuda
+            else:
+                sub._forward_method = sub.forward_native
+        if isinstance(sub, torch.nn.Module):
+            _to_torch(sub, reverse)
+
+
+@contextmanager
+def patch_model(
+    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+):
+    backup_ca_comm = None
+
+    try:
+        if use_compile:
+            _to_torch(model)
+            monkey_patch_vllm_all_gather()
+            backup_ca_comm = tp_group.ca_comm
+            tp_group.ca_comm = None
+            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+        else:
+            yield model.forward
+    finally:
+        if use_compile:
+            _to_torch(model, reverse=True)
+            monkey_patch_vllm_all_gather(reverse=True)
+            tp_group.ca_comm = backup_ca_comm


 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture):
+    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
@@ -54,6 +91,8 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )

+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
+
     def can_run(self, batch_size):
         return batch_size < self.max_bs

@@ -62,21 +101,23 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-                (
-
-
-
-
-
-
-
-
-
-
-
-
-
+                with patch_model(
+                    self.model_runner.model,
+                    bs in self.compile_bs,
+                    self.model_runner.tp_group,
+                ) as forward:
+                    (
+                        graph,
+                        input_buffers,
+                        output_buffers,
+                        flashinfer_handler,
+                    ) = self.capture_one_batch_size(bs, forward)
+                    self.graphs[bs] = graph
+                    self.input_buffers[bs] = input_buffers
+                    self.output_buffers[bs] = output_buffers
+                    self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs, forward):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream

@@ -129,9 +170,8 @@ class CudaGraphRunner:
                 skip_flashinfer_init=True,
             )
             input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
-
-
-            )
+
+            return forward(input_ids, input_metadata.positions, input_metadata)

         for _ in range(2):
             run_once()
@@ -152,8 +192,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.
-            self.position_ids_offsets.
+            self.seq_lens.fill_(1)
+            self.position_ids_offsets.zero_()
             self.out_cache_loc.zero_()

         # Common inputs
@@ -183,14 +223,18 @@ class CudaGraphRunner:
         else:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=
-
-
+                next_token_logprobs=(
+                    output.next_token_logprobs[:raw_bs]
+                    if output.next_token_logprobs is not None
+                    else None
+                ),
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=
-
-
+                decode_top_logprobs=(
+                    output.decode_top_logprobs[:raw_bs]
+                    if output.decode_top_logprobs is not None
+                    else None
+                ),
             )
         return output
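The new `patch_model` context manager above temporarily switches vLLM `CustomOp` modules to their native PyTorch forwards, patches all-gather, disables the custom all-reduce communicator, and yields a `torch.compile`d forward so that the batch sizes in `compile_bs` are captured as compiled CUDA graphs. For reference, here is a minimal sketch of the plain PyTorch capture-and-replay pattern the runner builds on; it is not sglang's code (the real runner also rebuilds FlashInfer decode wrappers per captured batch size).

```python
import torch

model = torch.nn.Linear(16, 16).cuda()
static_x = torch.zeros(8, 16, device="cuda")   # reused input buffer

# Warm up on a side stream before capture, as the CUDA graph recipe requires.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass; the outputs become static buffers too.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_x)

# Replay: copy new data into the captured input buffer, then relaunch the
# whole recorded kernel sequence with a single call.
static_x.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_out[0, :4])
```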
sglang/srt/managers/controller/infer_batch.py

@@ -7,7 +7,9 @@ from typing import List, Union

 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs

+from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
@@ -15,9 +17,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

-# Store some global server args
-global_server_args_dict = {}
-

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -84,6 +83,15 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids

         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- |   surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
+        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -134,7 +142,7 @@ class Req:
         return self.finished_reason is not None

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None

         if first_iter:
@@ -144,13 +152,11 @@ class Req:
             )

         all_ids = self.origin_input_ids_unpadded + self.output_ids
-
-        read_ids = all_ids[self.surr_offset :]
-
-        return surr_ids, read_ids, len(all_ids)
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

-    def
-
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]

         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -164,13 +170,7 @@ class Req:
         )

         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]

         return False, ""

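The `Req` changes above split incremental detokenization into `init_incremental_detokenize` and `get_next_inc_detokenization`: the request exposes the token window starting at `surr_offset`, decodes both the stable surrounding part and the extended part, and emits only the suffix once it no longer ends in a replacement character. A small sketch of that idea with a HuggingFace tokenizer follows; the tokenizer choice and offsets are illustrative only, not taken from the diff.

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice

all_ids = tok.encode("Hello world, this is incremental detokenization")
surr_offset, read_offset = 2, 5  # hypothetical offsets into all_ids

# Decode the stable surrounding window and the same window extended by new tokens.
surr_text = tok.decode(all_ids[surr_offset:read_offset])
new_text = tok.decode(all_ids[surr_offset:])

# Emit only the part beyond the surrounding text, and only if it does not end in
# a replacement character (i.e. an incomplete multi-byte sequence).
if len(new_text) > len(surr_text) and not new_text.endswith("\ufffd"):
    print(new_text[len(surr_text):])
```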
@@ -272,6 +272,7 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
+    extend_num_tokens: int = None

     # For processing logprobs
     return_logprob: bool = False
@@ -282,10 +283,6 @@ class Batch:
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

-    # Other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
@@ -327,6 +324,13 @@ class Batch:
         seq_lens = []

         req_pool_indices = self.req_to_token_pool.alloc(bs)
+
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
+
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
         for i in range(bs):
             flatten_input_ids.extend(input_ids[i])
@@ -398,10 +402,10 @@ class Batch:
         ).view(-1, 1)
         self.top_ps = torch.tensor(
             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
+        )
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
+        )
         self.frequency_penalties = torch.tensor(
             [r.sampling_params.frequency_penalty for r in reqs],
             dtype=torch.float,
@@ -428,7 +432,8 @@ class Batch:

     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
-
+
+        # TODO(lsyin): improve retraction policy for radix cache
         sorted_indices.sort(
             key=lambda i: (
                 len(self.reqs[i].output_ids),
@@ -440,7 +445,17 @@ class Batch:
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        while
+        while (
+            self.token_to_kv_pool.available_size()
+            < len(sorted_indices) * global_config.retract_decode_steps
+        ):
+            if len(sorted_indices) == 1:
+                # Corner case: only one request left
+                assert (
+                    self.token_to_kv_pool.available_size() > 0
+                ), "No space left for only one request"
+                break
+
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
@@ -465,7 +480,16 @@ class Batch:

         self.filter_batch(sorted_indices)

-
+        # Reqs in batch are filtered
+        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
+        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
+
+        new_estimate_ratio = (
+            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+        ) / total_max_new_tokens
+        new_estimate_ratio = min(1.0, new_estimate_ratio)
+
+        return retracted_reqs, new_estimate_ratio

     def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
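With the hunks above, `retract_decode` keeps evicting the most recently started requests until the KV pool has at least `global_config.retract_decode_steps` free token slots per remaining request, then returns the retracted requests together with a refreshed estimate of how much of the `max_new_tokens` budget the surviving requests will still use. A toy calculation of that ratio (all numbers made up for illustration):

```python
# retract_decode_steps is the per-request decode headroom the scheduler wants
# to keep free after a retraction; the value here is hypothetical.
retract_decode_steps = 20
decoded = [12, 40, 3]      # len(r.output_ids) for the remaining requests
budgets = [128, 256, 64]   # r.sampling_params.max_new_tokens for those requests

total_decoded_tokens = sum(decoded)
total_max_new_tokens = sum(budgets)

new_estimate_ratio = (
    total_decoded_tokens + retract_decode_steps * len(decoded)
) / total_max_new_tokens
new_estimate_ratio = min(1.0, new_estimate_ratio)
print(new_estimate_ratio)  # (55 + 60) / 448 ≈ 0.257
```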
@@ -499,7 +523,7 @@ class Batch:
             cur_output_ids = req.output_ids

             req.output_ids.extend(suffix_ids)
-            decode_res, new_text = req.
+            decode_res, new_text = req.get_next_inc_detokenization()
             if not decode_res:
                 req.output_ids = cur_output_ids
                 continue
@@ -518,6 +542,9 @@ class Batch:
                 req.output_ids = cur_output_ids
                 continue

+            # The decode status has diverged from detokenizer_manager
+            req.vid += 1
+
             # insert the old request into tree_cache
             if req_pool_indices_cpu is None:
                 req_pool_indices_cpu = self.req_pool_indices.tolist()
@@ -659,20 +686,20 @@ class Batch:

         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-
-
-
-
-
-        sampled_index = torch.ones(
-            probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
-        )
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
+
+        max_top_k_round, batch_size = 32, probs.shape[0]
+        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
+        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, self.top_ks, self.top_ps
         )
-
-
-
+
+        if torch.any(~success):
+            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                success, batch_next_token_ids, argmax_ids
+            )

         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
@@ -682,18 +709,7 @@ class Batch:
                 req.regex_fsm_state, batch_next_token_ids_cpu[i]
             )

-        return batch_next_token_ids
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
+        return batch_next_token_ids


 @dataclass
@@ -731,6 +747,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    use_ragged: bool = False

     @classmethod
     def create(
@@ -746,7 +763,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
+        use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -754,6 +774,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
+                use_ragged,
             )

         batch_size = len(req_pool_indices)
@@ -808,6 +829,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            use_ragged=use_ragged,
         )

         if model_runner.server_args.disable_flashinfer:
@@ -828,16 +850,19 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
+    use_ragged=False,
 ):
+    """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
+    total_num_tokens = int(torch.sum(seq_lens))

-    if
-        paged_kernel_lens = seq_lens
-    else:
+    if use_ragged:
         paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens

     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -870,14 +895,15 @@ def init_flashinfer_args(
     qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

-
-
-
-
-
-
-
-
+    if use_ragged:
+        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+            qo_indptr,
+            qo_indptr,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+        )

     # cached part
     model_runner.flashinfer_prefill_wrapper_paged.end_forward()
@@ -894,6 +920,7 @@ def init_flashinfer_args(


 def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
     batch_size = len(seq_lens)
     max_seq_len = int(torch.max(seq_lens))
     start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")