sglang-0.1.14-py3-none-any.whl → sglang-0.1.21-py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/backend/runtime_endpoint.py
CHANGED
@@ -1,14 +1,14 @@
 import json
-from typing import
+from typing import List, Optional
 
 import numpy as np
-
+
 from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import
-from sglang.utils import
+from sglang.lang.ir import SglSamplingParams
+from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
@@ -33,7 +33,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
         self.model_info = res.json()
 
         self.chat_template = get_chat_template_by_model_path(
@@ -49,7 +49,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
-
+        self._assert_success(res)
 
     def get_server_args(self):
         res = http_request(
@@ -57,6 +57,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
+        self._assert_success(res)
         return res.json()
 
     def get_chat_template(self):
@@ -70,17 +71,19 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
 
     def commit_lazy_operations(self, s: StreamExecutor):
+        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
+        self._add_images(s, data)
         res = http_request(
             self.base_url + "/generate",
-            json=
+            json=data,
             auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
 
     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -92,7 +95,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
 
     def generate(
         self,
@@ -104,6 +107,7 @@ class RuntimeEndpoint(BaseBackend):
                 "text": s.text_,
                 "sampling_params": {
                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                     **sampling_params.to_srt_kwargs(),
                 },
             }
@@ -112,6 +116,7 @@ class RuntimeEndpoint(BaseBackend):
                 "text": s.text_,
                 "sampling_params": {
                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                     "dtype": "int",
                     **sampling_params.to_srt_kwargs(),
                 },
@@ -119,6 +124,16 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
+            value = getattr(sampling_params, item, None)
+            if value is not None:
+                data[item] = value
+
         self._add_images(s, data)
 
         res = http_request(
@@ -128,6 +143,8 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
+
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]
@@ -142,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
                 "text": s.text_,
                 "sampling_params": {
                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                     **sampling_params.to_srt_kwargs(),
                 },
             }
@@ -150,6 +168,7 @@ class RuntimeEndpoint(BaseBackend):
                 "text": s.text_,
                 "sampling_params": {
                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                     "dtype": "int",
                     **sampling_params.to_srt_kwargs(),
                 },
@@ -157,10 +176,20 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
+            value = getattr(sampling_params, item, None)
+            if value is not None:
+                data[item] = value
+
         data["stream"] = True
         self._add_images(s, data)
 
-
+        res = http_request(
             self.base_url + "/generate",
             json=data,
             stream=True,
@@ -168,23 +197,19 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
         pos = 0
 
-
-        for chunk in response.iter_lines(decode_unicode=False):
+        for chunk in res.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
             if chunk and chunk.startswith("data:"):
                 if chunk == "data: [DONE]":
                     break
                 data = json.loads(chunk[5:].strip("\n"))
-
+                chunk_text = data["text"][pos:]
                 meta_info = data["meta_info"]
-                pos += len(
-
-                yield text, meta_info
-
-        if len(incomplete_text) > 0:
-            yield incomplete_text, meta_info
+                pos += len(chunk_text)
+                yield chunk_text, meta_info
 
     def select(
         self,
@@ -204,7 +229,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
         prompt_len = res.json()["meta_info"]["prompt_tokens"]
 
         # Compute logprob
@@ -222,15 +247,21 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
         obj = res.json()
-
+        normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-
+        decision = choices[np.argmax(normalized_prompt_logprobs)]
+        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
+        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
 
-
-
+        return (
+            decision,
+            normalized_prompt_logprobs,
+            prefill_token_logprobs,
+            decode_token_logprobs,
+        )
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
@@ -240,9 +271,13 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-
+        self._assert_success(res)
 
     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
             data["image_data"] = s.images_[0][1]
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
sglang/backend/vertexai.py
CHANGED
sglang/bench_latency.py
ADDED
@@ -0,0 +1,320 @@
+"""
+Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+
+# Usage (latency test):
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+
+# Usage (correctness test):
+python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+
+### Reference output:
+prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+        [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+        [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
+       device='cuda:0', dtype=torch.float16)
+prefill logits (final) tensor([[-8.3203, -7.1211, 3.3379, ..., -4.9570, -4.1328, -3.4141],
+        [-8.9062, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0742],
+        [-9.6328, -9.0547, 4.0117, ..., -5.3047, -4.7148, -4.4609]],
+       device='cuda:0', dtype=torch.float16)
+<s> The capital of France is.
+The capital of the United States is Washington, D.C.
+
+<s> The capital of the United Kindom is.
+The capital of the United Kingdom is London.
+The capital of the
+<s> Today is a sunny day and I like go for a walk in the park.
+I'm going to the park
+"""
+
+import argparse
+import dataclasses
+import logging
+import multiprocessing
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.model_config import ModelConfig
+from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import suppress_other_loggers
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    batch_size: int = 1
+    input_len: int = 1024
+    output_len: int = 4
+    correctness_test: bool = False
+    # This is only used for correctness test
+    cut_len: int = 4
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
+        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument("--correctness-test", action="store_true")
+        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+def load_model(server_args, tp_rank):
+    suppress_other_loggers()
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    model_config = ModelConfig(path=server_args.model_path)
+    model_runner = ModelRunner(
+        model_config=model_config,
+        mem_fraction_static=server_args.mem_fraction_static,
+        gpu_id=tp_rank,
+        tp_rank=tp_rank,
+        tp_size=server_args.tp_size,
+        nccl_port=28888,
+        server_args=server_args,
+    )
+    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+    tokenizer = get_tokenizer(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+    if server_args.tp_size > 1:
+        dist.barrier()
+    return model_runner, tokenizer
+
+
+def prepare_inputs(bench_args, tokenizer):
+    prompts = [
+        "The capital of France is",
+        "The capital of the United Kindom is",
+        "Today is a sunny day and I like",
+    ]
+    input_ids = [tokenizer.encode(p) for p in prompts]
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(prompts)):
+        assert len(input_ids[i]) > bench_args.cut_len
+
+        tmp_input_ids = input_ids[i][: bench_args.cut_len]
+        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return input_ids, reqs
+
+
+def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+    for i in range(len(reqs)):
+        req = reqs[i]
+        req.input_ids += input_ids[i][bench_args.cut_len :]
+        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+            i, : bench_args.cut_len
+        ]
+    return reqs
+
+
+def prepare_synthetic_inputs(bench_args, tokenizer):
+    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(input_ids)):
+        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return reqs
+
+
+def extend(reqs, model_runner):
+    batch = Batch.init_new(
+        reqs=reqs,
+        req_to_token_pool=model_runner.req_to_token_pool,
+        token_to_kv_pool=model_runner.token_to_kv_pool,
+        tree_cache=None,
+    )
+    batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
+    output = model_runner.forward(batch, ForwardMode.EXTEND)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits, batch
+
+
+def decode(input_token_ids, batch, model_runner):
+    batch.prepare_for_decode(input_token_ids.cpu().numpy())
+    output = model_runner.forward(batch, ForwardMode.DECODE)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits
+
+
+@torch.inference_mode()
+def correctness_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+
+    # Prepare inputs
+    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+
+    if bench_args.cut_len > 0:
+        # Prefill
+        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+        rank_print("prefill logits (first half)", next_token_logits)
+
+    # Prepare extend inputs
+    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+
+    # Extend
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print("prefill logits (final)", next_token_logits)
+
+    # Decode
+    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
+    for _ in range(bench_args.output_len):
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        for i in range(len(reqs)):
+            output_ids[i].append(next_token_ids[i])
+
+    # Print
+    for i in range(len(reqs)):
+        rank_print(tokenizer.decode(output_ids[i]))
+
+
+def latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+    rank_print(
+        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
+    )
+
+    # Prepare inputs
+    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+
+    def clear():
+        model_runner.req_to_token_pool.clear()
+        model_runner.token_to_kv_pool.clear()
+
+    @torch.inference_mode()
+    def run_once(output_len):
+        # Prefill
+        torch.cuda.synchronize()
+        tot_latency = 0
+        tic = time.time()
+        next_token_ids, _, batch = extend(reqs, model_runner)
+        torch.cuda.synchronize()
+        prefill_latency = time.time() - tic
+        tot_latency += prefill_latency
+        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
+        rank_print(
+            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+        )
+
+        # Decode
+        for i in range(output_len):
+            torch.cuda.synchronize()
+            tic = time.time()
+            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+            torch.cuda.synchronize()
+            latency = time.time() - tic
+            tot_latency += latency
+            throughput = bench_args.batch_size / latency
+            if i < 5:
+                rank_print(
+                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                )
+        avg_decode_latency = (tot_latency - prefill_latency) / output_len
+        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
+        rank_print(
+            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+        )
+
+        throughput = (
+            (bench_args.input_len + bench_args.output_len)
+            * bench_args.batch_size
+            / tot_latency
+        )
+        rank_print(
+            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+        )
+
+    # Warm up
+    run_once(4)
+    clear()
+
+    # Run again
+    run_once(bench_args.output_len)
+
+
+def main(server_args, bench_args):
+    print(bench_args)
+
+    if bench_args.correctness_test:
+        work_func = correctness_test
+    else:
+        work_func = latency_test
+
+    workers = []
+    for tp_rank in range(server_args.tp_size):
+        proc = multiprocessing.Process(
+            target=work_func,
+            args=(
+                server_args,
+                bench_args,
+                tp_rank,
+            ),
+        )
+        proc.start()
+        workers.append(proc)
+
+    for proc in workers:
+        proc.join()
+
+    proc.terminate()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    main(server_args, bench_args)
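The throughput figures printed by `run_once` follow directly from the measured wall-clock times: prefill throughput counts every input token in the batch, each decode step counts one token per request, and the total counts input plus output tokens over the whole run. A standalone restatement of that arithmetic, with made-up latencies standing in for real measurements:

```python
def summarize(batch_size, input_len, output_len, prefill_latency, decode_latencies):
    """Restate the throughput arithmetic used in bench_latency.run_once."""
    prefill_throughput = input_len * batch_size / prefill_latency
    tot_latency = prefill_latency + sum(decode_latencies)
    avg_decode_latency = sum(decode_latencies) / output_len
    avg_decode_throughput = batch_size / avg_decode_latency
    total_throughput = (input_len + output_len) * batch_size / tot_latency
    return prefill_throughput, avg_decode_throughput, total_throughput

# Example with invented timings: batch of 8, 1024-token prompts, 4 decode steps.
print(summarize(8, 1024, 4, prefill_latency=0.5, decode_latencies=[0.02] * 4))
```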
sglang/global_config.py
CHANGED
@@ -8,17 +8,38 @@ class GlobalConfig:
         # 2: output final text after every run
         self.verbosity = 0
 
+        # Default backend of the language
         self.default_backend = None
 
-        #
+        # Runtime constants: Request dependency time due to network delay
+        self.request_dependency_delay = 0.02
+        self.wait_for_new_request_delay = 0.0006
+
+        # Runtime constants: New generation token ratio estimation
+        self.base_new_token_ratio = 0.4
+        self.base_min_new_token_ratio = 0.2
+        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_recovery = 0.05
+
+        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192
+
+        # Runtime constants: others
+        self.num_continue_decode_steps = 10
+        self.flashinfer_workspace_size = 192 * 1024 * 1024
+
+        # Output tokenization configs
         self.skip_special_tokens_in_output = True
+        self.spaces_between_special_tokens_in_out = True
 
-        #
+        # Interpreter optimization configs
         self.eager_fill_image = False
-        self.
+        self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True
 
+        # Deprecated
         # Choices: ["no_adjust", "adjust_cache"]
         # no_adjust: Do not adjust the position embedding of KV cache.
         # adjust_cache: Adjust the position embedding of KV cache.