sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
ADDED
@@ -0,0 +1,299 @@
+"""
+Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+
+# Usage (latency test):
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+
+# Usage (correctness test):
+python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+
+### Reference output:
+prefill logits (first half) tensor([[-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
+        [-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
+        [ -9.1875, -10.2500,   2.7109,  ...,  -4.3359,  -4.0664,  -4.1328]],
+       device='cuda:0', dtype=torch.float16)
+prefill logits (final) tensor([[-8.3203, -7.1211,  3.3379,  ..., -4.9570, -4.1328, -3.4141],
+        [-8.9062, -9.0156,  4.1445,  ..., -4.9922, -4.4961, -4.0742],
+        [-9.6328, -9.0547,  4.0117,  ..., -5.3047, -4.7148, -4.4609]],
+       device='cuda:0', dtype=torch.float16)
+<s> The capital of France is.
+The capital of the United States is Washington, D.C.
+
+<s> The capital of the United Kindom is.
+The capital of the United Kingdom is London.
+The capital of the
+<s> Today is a sunny day and I like go for a walk in the park.
+I'm going to the park
+"""
+
+import argparse
+import dataclasses
+import logging
+import multiprocessing
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.model_config import ModelConfig
+from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import suppress_other_loggers
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    batch_size: int = 1
+    input_len: int = 1024
+    output_len: int = 4
+    correctness_test: bool = False
+    # This is only used for correctness test
+    cut_len: int = 4
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
+        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument("--correctness-test", action="store_true")
+        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+def load_model(server_args, tp_rank):
+    suppress_other_loggers()
+
+    model_config = ModelConfig(path=server_args.model_path)
+    model_runner = ModelRunner(
+        model_config=model_config,
+        mem_fraction_static=server_args.mem_fraction_static,
+        gpu_id=tp_rank,
+        tp_rank=tp_rank,
+        tp_size=server_args.tp_size,
+        nccl_port=28888,
+        server_args=server_args,
+    )
+    print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+    tokenizer = get_tokenizer(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+    if server_args.tp_size > 1:
+        dist.barrier()
+    return model_runner, tokenizer
+
+
+def prepare_inputs(bench_args, tokenizer):
+    prompts = [
+        "The capital of France is",
+        "The capital of the United Kindom is",
+        "Today is a sunny day and I like",
+    ]
+    input_ids = [tokenizer.encode(p) for p in prompts]
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(prompts)):
+        assert len(input_ids[i]) > bench_args.cut_len
+
+        tmp_input_ids = input_ids[i][:bench_args.cut_len]
+        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return input_ids, reqs
+
+
+def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+    for i in range(len(reqs)):
+        req = reqs[i]
+        req.input_ids += input_ids[i][bench_args.cut_len:]
+        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+            i, :bench_args.cut_len
+        ]
+    return reqs
+
+
+def prepare_synthetic_inputs(bench_args, tokenizer):
+    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(input_ids)):
+        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return reqs
+
+
+def extend(reqs, model_runner):
+    batch = Batch.init_new(
+        reqs=reqs,
+        req_to_token_pool=model_runner.req_to_token_pool,
+        token_to_kv_pool=model_runner.token_to_kv_pool,
+        tree_cache=None)
+    batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
+    output = model_runner.forward(batch, ForwardMode.EXTEND)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits, batch
+
+
+def decode(input_token_ids, batch, model_runner):
+    batch.prepare_for_decode(input_token_ids.cpu().numpy())
+    output = model_runner.forward(batch, ForwardMode.DECODE)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits
+
+
+def correctness_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+
+    # Prepare inputs
+    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+
+    # Prefill
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print("prefill logits (first half)", next_token_logits)
+
+    # Prepare extend inputs
+    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+
+    # Extend
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print("prefill logits (final)", next_token_logits)
+
+    # Decode
+    output_ids = [list(req.input_ids) for req in reqs]
+    for _ in range(bench_args.output_len):
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        for i in range(len(reqs)):
+            output_ids[i].append(next_token_ids[i])
+
+    # Print
+    for i in range(len(reqs)):
+        print(tokenizer.decode(output_ids[i]))
+
+
+def latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+    print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
+
+    # Prepare inputs
+    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+
+    def clear():
+        model_runner.req_to_token_pool.clear()
+        model_runner.token_to_kv_pool.clear()
+
+    @torch.inference_mode()
+    def run_once(output_len):
+        # Prefill
+        torch.cuda.synchronize()
+        tot_latency = 0
+        tic = time.time()
+        next_token_ids, _, batch = extend(reqs, model_runner)
+        torch.cuda.synchronize()
+        prefill_latency = time.time() - tic
+        tot_latency += prefill_latency
+        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
+        rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+
+        # Decode
+        for i in range(output_len):
+            torch.cuda.synchronize()
+            tic = time.time()
+            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+            torch.cuda.synchronize()
+            latency = time.time() - tic
+            tot_latency += latency
+            throughput = bench_args.batch_size / latency
+            if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+        avg_decode_latency = (tot_latency - prefill_latency) / output_len
+        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
+        rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
+
+        throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
+        rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+
+    # Warm up
+    run_once(4)
+    clear()
+
+    # Run again
+    run_once(bench_args.output_len)
+
+
+def main(server_args, bench_args):
+    print(bench_args)
+
+    if bench_args.correctness_test:
+        work_func = correctness_test
+    else:
+        work_func = latency_test
+
+    workers = []
+    for tp_rank in range(server_args.tp_size):
+        proc = multiprocessing.Process(
+            target=work_func,
+            args=(
+                server_args,
+                bench_args,
+                tp_rank,
+            ),
+        )
+        proc.start()
+        workers.append(proc)
+
+    for proc in workers:
+        proc.join()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    main(server_args, bench_args)
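As a quick sanity check on the metrics that run_once prints, the snippet below reproduces the same throughput arithmetic with assumed (hypothetical) latencies; the numbers are illustrative, not measurements from any hardware.

# Illustrative only: the throughput accounting used in run_once, with assumed latencies.
batch_size, input_len, output_len = 1, 1024, 4

prefill_latency = 0.05                      # assumed seconds for the prefill (extend) pass
decode_latencies = [0.010] * output_len     # assumed seconds per decode step

prefill_throughput = input_len * batch_size / prefill_latency            # 20480.0 token/s
tot_latency = prefill_latency + sum(decode_latencies)                    # 0.09 s
avg_decode_latency = sum(decode_latencies) / output_len                  # 0.01 s
avg_decode_throughput = batch_size / avg_decode_latency                  # 100.0 token/s
total_throughput = (input_len + output_len) * batch_size / tot_latency   # ~11422 token/s

print(prefill_throughput, avg_decode_throughput, total_throughput)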
sglang/global_config.py
CHANGED
@@ -26,7 +26,17 @@ class GlobalConfig:
         self.concate_and_append_mode = "no_adjust"

         # Request dependency time due to network delay
-        self.
-
+        self.request_dependency_delay = 0.02
+        self.wait_for_new_request_delay = 0.0006
+
+        # New generation token ratio estimation
+        self.base_new_token_ratio = 0.4
+        self.base_min_new_token_ratio = 0.2
+        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_recovery = 0.05
+
+        # The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192

 global_config = GlobalConfig()
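The four new ratio constants parameterize an estimate of how many additional tokens running requests are expected to generate. The snippet below is only an illustrative sketch of a decay-and-recover schedule built from these constants; it is not the scheduler's actual code, and the recovery trigger is an assumption.

# Illustrative sketch only, not the tp_worker scheduler code.
base_new_token_ratio = 0.4
base_min_new_token_ratio = 0.2
new_token_ratio_decay = 0.0001
new_token_ratio_recovery = 0.05

ratio = base_new_token_ratio
for _ in range(5000):  # decay a little on every scheduling step
    ratio = max(ratio - new_token_ratio_decay, base_min_new_token_ratio)
print(f"after 5000 steps: {ratio:.2f}")  # 0.20, clamped at the floor

# Assumed trigger: when requests must be retracted (e.g. under memory pressure),
# the estimate is bumped back toward the base value.
ratio = min(ratio + new_token_ratio_recovery, base_new_token_ratio)
print(f"after one recovery bump: {ratio:.2f}")  # 0.25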
sglang/lang/compiler.py
CHANGED
@@ -4,7 +4,7 @@ from queue import Queue
 from typing import List, Union

 from sglang.global_config import global_config
-from sglang.lang.interpreter import ProgramState, StreamExecutor,
+from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
 from sglang.lang.ir import (
     SglArgument,
     SglConstantText,
@@ -184,7 +184,7 @@ class CompiledFunction:

         # Extract prefix by tracing and cache it
         if len(batch_kwargs) > 1:
-
+            cache_program(self.function, backend)

         # Run all programs
         if num_threads == "auto":
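The batch path now calls cache_program(self.function, backend) so the shared prompt prefix is traced and cached once before the batch is launched. The toy sketch below only illustrates the general idea (compute a shared prefix a single time and reuse it per batch item); it does not use sglang's APIs.

# Toy sketch, not sglang's implementation.
from functools import lru_cache

@lru_cache(maxsize=None)
def build_prefix(system_prompt: str) -> str:
    print("tracing prefix once")
    return f"[SYSTEM] {system_prompt}\n"

def run_one(system_prompt: str, question: str) -> str:
    return build_prefix(system_prompt) + f"[USER] {question}"

batch = ["What is 1+1?", "Name a prime number.", "Capital of France?"]
outputs = [run_one("You are a helpful assistant.", q) for q in batch]
# "tracing prefix once" is printed a single time even though the batch has 3 items.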
sglang/lang/interpreter.py
CHANGED
@@ -6,6 +6,7 @@ import multiprocessing
 import queue
 import threading
 import uuid
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -30,7 +31,11 @@ from sglang.lang.ir import (
     SglVarScopeEnd,
     SglVideo,
 )
-from sglang.utils import
+from sglang.utils import (
+    encode_image_base64,
+    encode_video_base64,
+    get_exception_traceback,
+)


 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -61,7 +66,7 @@ def run_program(
         default_sampling_para,
         chat_template=None,
         stream=stream,
-
+        num_api_spec_tokens=program.num_api_spec_tokens,
     )
     state = ProgramState(stream_executor)

@@ -173,7 +178,7 @@ class StreamExecutor:
         default_sampling_para,
         chat_template,
         stream,
-
+        num_api_spec_tokens=None,
         use_thread=True,
     ):
         self.sid = uuid.uuid4().hex
@@ -181,20 +186,16 @@ class StreamExecutor:
         self.arguments: Dict[str, Any] = arguments
         self.default_sampling_para = default_sampling_para
         self.stream = stream
-        self.api_num_spec_tokens = api_num_spec_tokens

         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
-        self.
+        self.error_ = None

         # For completion
         self.text_ = ""  # The full text

-        # For speculative execution
-        self.speculated_text = ""
-
         # For chat
         self.messages_ = []  # The messages in the OpenAI API format
         self.chat_template = chat_template or self.backend.get_chat_template()
@@ -208,6 +209,10 @@ class StreamExecutor:
         # For fork/join
         self.fork_start_text_pos = None

+        # For speculative execution
+        self.num_api_spec_tokens = num_api_spec_tokens
+        self.speculated_text = ""
+
         # Worker thread
         self.use_thread = use_thread
         if self.use_thread:
@@ -286,6 +291,8 @@ class StreamExecutor:
             exes[i].fork_start_text_pos = len(self.text_)
             exes[i].images_ = list(self.images_)

+        # TODO(ying): handle API speculative execution
+
         return exes

     def text(self):
@@ -296,6 +303,10 @@ class StreamExecutor:
         self.sync()
         return self.messages_

+    def error(self):
+        self.sync()
+        return self.error_
+
     def end(self):
         if self.use_thread:
             if self.worker.is_alive():
@@ -314,7 +325,7 @@ class StreamExecutor:
             try:
                 self._execute(expr)
             except Exception as e:
-
+                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
                 error = e
                 break
             self.queue.task_done()
@@ -334,7 +345,7 @@ class StreamExecutor:
         if self.stream_var_event:
             for name in self.stream_var_event:
                 self.stream_var_event[name].set()
-        self.
+        self.error_ = error

         if self.stream_text_event:
             self.stream_text_event.set()
@@ -383,12 +394,23 @@ class StreamExecutor:
         else:
             raise ValueError(f"Unknown type: {type(other)}")

-    def _execute_fill(self, value: str):
+    def _execute_fill(self, value: str, prefix=False):
         value = str(value)
+
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+            and not prefix
+        ):
+            self.backend.spec_fill(value)
+            return
+
         if self.speculated_text.startswith(value):
             self.speculated_text = self.speculated_text[len(value) :]
         else:
             self.speculated_text = ""
+
         self.text_ += value

     def _execute_image(self, expr: SglImage):
@@ -413,65 +435,80 @@ class StreamExecutor:
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)

+    def _spec_gen(self, sampling_params):
+        stop = sampling_params.stop
+        max_new_tokens = sampling_params.max_new_tokens
+        meta_info = {}
+
+        def regen():
+            nonlocal meta_info
+
+            sampling_params.max_new_tokens = max(
+                sampling_params.max_new_tokens, self.num_api_spec_tokens
+            )
+            sampling_params.stop = None
+            self.speculated_text, meta_info = self.backend.generate(
+                self, sampling_params=sampling_params
+            )
+
+        def find_stop():
+            if isinstance(stop, str):
+                return self.speculated_text.find(stop)
+            elif isinstance(stop, (tuple, list)):
+                pos = -1
+                for stop_str in stop:
+                    stop_pos = self.speculated_text.find(stop_str)
+                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                        pos = stop_pos
+                return pos
+            else:
+                raise Exception("Wrong type of stop in sampling parameters.")
+
+        if stop is None:
+            if len(self.speculated_text) < max_new_tokens:
+                regen()
+            comp = self.speculated_text[:max_new_tokens]
+            self.speculated_text = self.speculated_text[max_new_tokens:]
+        elif isinstance(stop, (str, list, tuple)):
+            if self.speculated_text == "":
+                regen()
+            stop_pos = find_stop()
+            if stop_pos == -1:
+                stop_pos = min(
+                    sampling_params.max_new_tokens,
+                    len(self.speculated_text),
+                )
+            comp = self.speculated_text[:stop_pos]
+            self.speculated_text = self.speculated_text[stop_pos:]
+        else:
+            raise ValueError("Wrong type of stop in sampling parameters.")
+
+        return comp, meta_info
+
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name

         if not self.stream:
-            if self.
-                stop = sampling_params.stop
-                max_new_tokens = sampling_params.max_new_tokens
-                meta_info = {}
-
-                def regen():
-                    sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.api_num_spec_tokens
-                    )
-                    sampling_params.stop = None
-                    self.speculated_text, meta_info = self.backend.generate(
-                        self, sampling_params=sampling_params
-                    )
-
-                def find_stop():
-                    if isinstance(stop, str):
-                        return self.speculated_text.find(stop), len(stop)
-                    elif isinstance(stop, (tuple, list)):
-                        pos = -1
-                        stop_len = 0
-                        for stop_str in stop:
-                            stop_pos = self.speculated_text.find(stop_str)
-                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
-                                pos = stop_pos
-                                stop_len = len(stop_str)
-                        return pos, stop_len
-                    else:
-                        raise Exception("Wrong type of stop in sampling parameters.")
-
-                if stop is None:
-                    if len(self.speculated_text) < max_new_tokens:
-                        regen()
-                    comp = self.speculated_text[:max_new_tokens]
-                    self.speculated_text = self.speculated_text[max_new_tokens:]
-                elif isinstance(stop, (str, list, tuple)):
-                    if self.speculated_text == "":
-                        regen()
-                    stop_pos, stop_len = find_stop()
-                    if stop_pos == -1:
-                        stop_pos, stop_len = (
-                            min(
-                                sampling_params.max_new_tokens,
-                                len(self.speculated_text),
-                            ),
-                            0,
-                        )
-                    comp = self.speculated_text[:stop_pos]
-                    self.speculated_text = self.speculated_text[stop_pos:]
-                else:
-                    raise ValueError("Wrong type of stop in sampling parameters.")
-            else:
+            if self.num_api_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
-                    self,
+                    self,
+                    sampling_params=sampling_params,
                 )
+            else:
+                if self.backend.is_chat_model:
+                    # Speculative execution on models with only chat interface.
+                    # Store the calls into a temporary list.
+                    # They will be lazily executed later.
+                    comp, meta_info = self.backend.generate(
+                        self,
+                        sampling_params=sampling_params,
+                        spec_var_name=name,
+                    )
+                    return
+
+                else:  # Speculative execution on models with completion interface
+                    comp, meta_info = self._spec_gen(sampling_params)

             self.text_ += comp

@@ -479,6 +516,9 @@ class StreamExecutor:
             self.meta_info[name] = meta_info
             self.variable_event[name].set()
         else:
+            assert (
+                self.num_api_spec_tokens is None
+            ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
            )
@@ -534,10 +574,19 @@ class StreamExecutor:

         prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

-        self._execute_fill(prefix)
+        self._execute_fill(prefix, prefix=True)
         self.cur_role_begin_pos = len(self.text_)

     def _execute_role_end(self, expr: SglRoleEnd):
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+        ):
+            # Execute the stored lazy generation calls
+            self.backend.role_end_generate(self)
+        self.cur_role = None
+
         new_text = self.text_[self.cur_role_begin_pos :].lstrip()

         _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -564,8 +613,6 @@ class StreamExecutor:
         # OpenAI chat API format
         self.messages_.append({"role": expr.role, "content": new_text})

-        self.cur_role = None
-
     def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
         self.variables[expr.name] = int(len(self.text_))

@@ -709,7 +756,7 @@ class ProgramState:
         return self.stream_executor.sync()

     def error(self):
-        return self.stream_executor.error
+        return self.stream_executor.error()

     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream: