sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -2
- sglang/api.py +30 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +317 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +41 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -2
- sglang/lang/ir.py +74 -28
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +68 -9
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +280 -169
- sglang/srt/layers/logits_processor.py +106 -42
- sglang/srt/layers/radix_attention.py +53 -29
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +144 -69
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +9 -4
- sglang/srt/managers/controller/model_runner.py +167 -55
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +156 -134
- sglang/srt/managers/detokenizer_manager.py +19 -21
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/tokenizer_manager.py +16 -14
- sglang/srt/model_config.py +89 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +12 -5
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +35 -25
- sglang/srt/openai_protocol.py +2 -2
- sglang/srt/server.py +69 -19
- sglang/srt/server_args.py +76 -43
- sglang/srt/utils.py +177 -35
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
- sglang-0.1.19.dist-info/RECORD +81 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.1.17"
+__version__ = "0.1.19"
 
 # SGL API Components
 from sglang.api import (
@@ -24,10 +24,10 @@ from sglang.api import (
 
 # SGL Backends
 from sglang.backend.anthropic import Anthropic
+from sglang.backend.litellm import LiteLLM
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.backend.vertexai import VertexAI
-from sglang.backend.litellm import LiteLLM
 
 # Global Configurations
 from sglang.global_config import global_config
sglang/api.py CHANGED
@@ -1,4 +1,4 @@
-"""
+"""Public APIs of the language."""
 
 import os
 import re
@@ -43,14 +43,14 @@ def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
 
-def flush_cache(backend: BaseBackend = None):
+def flush_cache(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return False
     return backend.flush_cache()
 
 
-def get_server_args(backend: BaseBackend = None):
+def get_server_args(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -67,10 +67,16 @@ def gen(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
     regex: Optional[str] = None,
 ):
+    """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
+
     if choices:
         return SglSelect(name, choices, 0.0 if temperature is None else temperature)
 
@@ -91,6 +97,10 @@ def gen(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         dtype,
         regex,
     )
@@ -106,6 +116,10 @@ def gen_int(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -117,6 +131,10 @@ def gen_int(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         int,
         None,
     )
@@ -132,6 +150,10 @@ def gen_string(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    return_logprob: Optional[bool] = None,
+    logprob_start_len: Optional[int] = None,
+    top_logprobs_num: Optional[int] = None,
+    return_text_in_logprobs: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -143,6 +165,10 @@ def gen_string(
         frequency_penalty,
         presence_penalty,
         ignore_eos,
+        return_logprob,
+        logprob_start_len,
+        top_logprobs_num,
+        return_text_in_logprobs,
         str,
         None,
     )
@@ -158,7 +184,7 @@ def video(path: str, num_frames: int):
 
 def select(
     name: Optional[str] = None,
-    choices: List[str] = None,
+    choices: Optional[List[str]] = None,
     temperature: float = 0.0,
 ):
     assert choices is not None
sglang/backend/litellm.py CHANGED
@@ -13,7 +13,6 @@ except ImportError as e:
 
 
 class LiteLLM(BaseBackend):
-
     def __init__(
         self,
         model_name,
@@ -33,7 +32,8 @@ class LiteLLM(BaseBackend):
         self.model_name = model_name
 
         self.chat_template = chat_template or get_chat_template_by_model_path(
-            model_name
+            model_name
+        )
 
         self.client_params = {
             "api_key": api_key,
sglang/backend/openai.py CHANGED
@@ -1,7 +1,7 @@
+import dataclasses
 import logging
 import time
 import warnings
-import dataclasses
 from typing import Callable, List, Optional, Union
 
 import numpy as np
@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
     def get_chat_template(self):
         return self.chat_template
 
-    def _prepare_spec_execution(
-
+    def _prepare_spec_execution(
+        self,
+        sampling_params: SglSamplingParams,
+        num_api_spec_tokens: int,
+        spec_var_name: str,
+    ):
         if "max_tokens" not in self.spec_kwargs:
             self.spec_kwargs["max_tokens"] = num_api_spec_tokens
         else:
-            assert (
-                self.spec_kwargs["max_tokens"] == num_api_spec_tokens
-            )
+            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
 
         params = sampling_params.to_openai_kwargs()
         for key, value in params.items():
@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
                 )
                 prompt = s.messages_
             else:
-                return self._prepare_spec_execution(
-                    s.num_api_spec_tokens, spec_var_name
+                return self._prepare_spec_execution(
+                    sampling_params, s.num_api_spec_tokens, spec_var_name
+                )
         else:
             prompt = s.text_
 
@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
         ret_str = ret.choices[0].text
         ret_token = self.tokenizer.encode(ret_str)[0]
         self.token_usage.prompt_tokens += ret.usage.prompt_tokens
-        self.token_usage.completion_tokens= ret.usage.completion_tokens
+        self.token_usage.completion_tokens = ret.usage.completion_tokens
 
         # TODO:
         # 1. return logits as the scores
@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
     return decision, scores, None, None
 
 
-def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
+def openai_completion(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
             return comp
 
 
-def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
+def openai_completion_stream(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
    for attempt in range(retries):
        try:
            if is_chat:
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                generator = client.chat.completions.create(
-                    messages=prompt,
-
+                    messages=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
                     if len(ret.choices) == 0:
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
                         yield content or "", {}
             else:
                 generator = client.completions.create(
-                    prompt=prompt,
-
+                    prompt=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
                     if len(ret.choices) == 0:
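Note: stream_options={"include_usage": True} is standard OpenAI Python SDK usage for retrieving token counts from a streamed response. The final chunk carries a usage object and an empty choices list, which is why the loops above skip chunks where len(ret.choices) == 0. A standalone sketch, assuming an OpenAI client configured via OPENAI_API_KEY and an arbitrary model name:

from openai import OpenAI

client = OpenAI()

stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage:
        # The last chunk has no choices; it only reports usage.
        print("\nprompt:", chunk.usage.prompt_tokens, "completion:", chunk.usage.completion_tokens)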
sglang/backend/runtime_endpoint.py CHANGED
@@ -1,18 +1,18 @@
 import json
-from typing import
+from typing import List, Optional
 
 import numpy as np
-import requests
 
 from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import
-from sglang.utils import
+from sglang.lang.ir import SglSamplingParams
+from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
+
     def __init__(
         self,
         base_url: str,
@@ -38,8 +38,7 @@ class RuntimeEndpoint(BaseBackend):
         self.model_info = res.json()
 
         self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_path"]
-        )
+            self.model_info["model_path"])
 
     def get_model_name(self):
         return self.model_info["model_path"]
@@ -125,6 +124,11 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
+        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+            value = getattr(sampling_params, item, None)
+            if value is not None:
+                data[item] = value
+
         self._add_images(s, data)
 
         res = http_request(
@@ -167,6 +171,11 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
+        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+            value = getattr(sampling_params, item, None)
+            if value is not None:
+                data[item] = value
+
         data["stream"] = True
         self._add_images(s, data)
 
@@ -181,21 +190,16 @@ class RuntimeEndpoint(BaseBackend):
         self._assert_success(res)
         pos = 0
 
-        incomplete_text = ""
         for chunk in res.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
             if chunk and chunk.startswith("data:"):
                 if chunk == "data: [DONE]":
                     break
                 data = json.loads(chunk[5:].strip("\n"))
-
+                chunk_text = data["text"][pos:]
                 meta_info = data["meta_info"]
-                pos += len(
-
-                yield text, meta_info
-
-        if len(incomplete_text) > 0:
-            yield incomplete_text, meta_info
+                pos += len(chunk_text)
+                yield chunk_text, meta_info
 
     def select(
         self,
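Note: the four fields copied into data above are sent to the runtime as top-level keys of the request body, next to the prompt and sampling parameters, and the results come back in the response's meta_info (the streaming loop above reads data["meta_info"]). A hedged sketch of an equivalent raw request; the /generate route, port 30000, and the text/sampling_params field names follow the sglang server's usual API and are assumed here rather than shown in this diff:

import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    # Logprob fields forwarded by RuntimeEndpoint in this release:
    "return_logprob": True,
    "logprob_start_len": 0,
    "top_logprobs_num": 5,
    "return_text_in_logprobs": True,
}
res = requests.post("http://localhost:30000/generate", json=payload)
print(res.json()["meta_info"])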
sglang/bench_latency.py ADDED
@@ -0,0 +1,317 @@
+"""
+Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+
+# Usage (latency test):
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+
+# Usage (correctness test):
+python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+
+### Reference output:
+prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+        [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+        [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
+       device='cuda:0', dtype=torch.float16)
+prefill logits (final) tensor([[-8.3203, -7.1211, 3.3379, ..., -4.9570, -4.1328, -3.4141],
+        [-8.9062, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0742],
+        [-9.6328, -9.0547, 4.0117, ..., -5.3047, -4.7148, -4.4609]],
+       device='cuda:0', dtype=torch.float16)
+<s> The capital of France is.
+The capital of the United States is Washington, D.C.
+
+<s> The capital of the United Kindom is.
+The capital of the United Kingdom is London.
+The capital of the
+<s> Today is a sunny day and I like go for a walk in the park.
+I'm going to the park
+"""
+
+import argparse
+import dataclasses
+import logging
+import multiprocessing
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.model_config import ModelConfig
+from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import suppress_other_loggers
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    batch_size: int = 1
+    input_len: int = 1024
+    output_len: int = 4
+    correctness_test: bool = False
+    # This is only used for correctness test
+    cut_len: int = 4
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
+        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument("--correctness-test", action="store_true")
+        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+def load_model(server_args, tp_rank):
+    suppress_other_loggers()
+
+    model_config = ModelConfig(path=server_args.model_path)
+    model_runner = ModelRunner(
+        model_config=model_config,
+        mem_fraction_static=server_args.mem_fraction_static,
+        gpu_id=tp_rank,
+        tp_rank=tp_rank,
+        tp_size=server_args.tp_size,
+        nccl_port=28888,
+        server_args=server_args,
+    )
+    print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+    tokenizer = get_tokenizer(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+    if server_args.tp_size > 1:
+        dist.barrier()
+    return model_runner, tokenizer
+
+
+def prepare_inputs(bench_args, tokenizer):
+    prompts = [
+        "The capital of France is",
+        "The capital of the United Kindom is",
+        "Today is a sunny day and I like",
+    ]
+    input_ids = [tokenizer.encode(p) for p in prompts]
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(prompts)):
+        assert len(input_ids[i]) > bench_args.cut_len
+
+        tmp_input_ids = input_ids[i][: bench_args.cut_len]
+        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return input_ids, reqs
+
+
+def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+    for i in range(len(reqs)):
+        req = reqs[i]
+        req.input_ids += input_ids[i][bench_args.cut_len :]
+        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+            i, : bench_args.cut_len
+        ]
+    return reqs
+
+
+def prepare_synthetic_inputs(bench_args, tokenizer):
+    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=BenchArgs.output_len,
+    )
+
+    reqs = []
+    for i in range(len(input_ids)):
+        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return reqs
+
+
+def extend(reqs, model_runner):
+    batch = Batch.init_new(
+        reqs=reqs,
+        req_to_token_pool=model_runner.req_to_token_pool,
+        token_to_kv_pool=model_runner.token_to_kv_pool,
+        tree_cache=None,
+    )
+    batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
+    output = model_runner.forward(batch, ForwardMode.EXTEND)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits, batch
+
+
+def decode(input_token_ids, batch, model_runner):
+    batch.prepare_for_decode(input_token_ids.cpu().numpy())
+    output = model_runner.forward(batch, ForwardMode.DECODE)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits
+
+
+@torch.inference_mode()
+def correctness_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+
+    # Prepare inputs
+    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+
+    if bench_args.cut_len > 0:
+        # Prefill
+        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+        rank_print("prefill logits (first half)", next_token_logits)
+
+    # Prepare extend inputs
+    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+
+    # Extend
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print("prefill logits (final)", next_token_logits)
+
+    # Decode
+    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
+    for _ in range(bench_args.output_len):
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        for i in range(len(reqs)):
+            output_ids[i].append(next_token_ids[i])
+
+    # Print
+    for i in range(len(reqs)):
+        print(tokenizer.decode(output_ids[i]))
+
+
+def latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+    print(
+        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
+    )
+
+    # Prepare inputs
+    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+
+    def clear():
+        model_runner.req_to_token_pool.clear()
+        model_runner.token_to_kv_pool.clear()
+
+    @torch.inference_mode()
+    def run_once(output_len):
+        # Prefill
+        torch.cuda.synchronize()
+        tot_latency = 0
+        tic = time.time()
+        next_token_ids, _, batch = extend(reqs, model_runner)
+        torch.cuda.synchronize()
+        prefill_latency = time.time() - tic
+        tot_latency += prefill_latency
+        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
+        rank_print(
+            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+        )
+
+        # Decode
+        for i in range(output_len):
+            torch.cuda.synchronize()
+            tic = time.time()
+            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+            torch.cuda.synchronize()
+            latency = time.time() - tic
+            tot_latency += latency
+            throughput = bench_args.batch_size / latency
+            if i < 5:
+                rank_print(
+                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                )
+        avg_decode_latency = (tot_latency - prefill_latency) / output_len
+        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
+        rank_print(
+            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+        )
+
+        throughput = (
+            (bench_args.input_len + bench_args.output_len)
+            * bench_args.batch_size
+            / tot_latency
+        )
+        rank_print(
+            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+        )
+
+    # Warm up
+    run_once(4)
+    clear()
+
+    # Run again
+    run_once(bench_args.output_len)
+
+
+def main(server_args, bench_args):
+    print(bench_args)
+
+    if bench_args.correctness_test:
+        work_func = correctness_test
+    else:
+        work_func = latency_test
+
+    workers = []
+    for tp_rank in range(server_args.tp_size):
+        proc = multiprocessing.Process(
+            target=work_func,
+            args=(
+                server_args,
+                bench_args,
+                tp_rank,
+            ),
+        )
+        proc.start()
+        workers.append(proc)
+
+    for proc in workers:
+        proc.join()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    main(server_args, bench_args)
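Note: the run_once() helper in bench_latency.py brackets every forward pass with torch.cuda.synchronize() before reading time.time(); CUDA kernel launches are asynchronous, so without the sync the timer would mostly measure launch overhead rather than execution time. The same pattern in isolation, as a small sketch that assumes a CUDA device is available:

import time

import torch

def time_gpu_op(fn, warmup=3, iters=10):
    # Average wall-clock latency of a CUDA operation, in seconds.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()  # make sure warm-up kernels have finished
    tic = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    return (time.time() - tic) / iters

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
print(f"matmul latency: {time_gpu_op(lambda: a @ a):.5f} s")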
sglang/global_config.py CHANGED
@@ -27,7 +27,7 @@ class GlobalConfig:
 
         # Request dependency time due to network delay
         self.request_dependency_delay = 0.02
-        self.wait_for_new_request_delay = 0.
+        self.wait_for_new_request_delay = 0.0006
 
         # New generation token ratio estimation
        self.base_new_token_ratio = 0.4
@@ -35,5 +35,9 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.0001
         self.new_token_ratio_recovery = 0.05
 
+        # The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192
+
 
 global_config = GlobalConfig()
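Note: since global_config is a plain module-level instance (global_config = GlobalConfig() above), the new layer_sync_threshold can be inspected or overridden from Python. A minimal sketch; whether an override takes effect depends on where and when the serving process reads the value, and 4096 is an arbitrary example rather than a recommended setting:

from sglang.global_config import global_config

print(global_config.layer_sync_threshold)  # 8192 in this release
# Lower the token count at which layer-wise CUDA sync kicks in during prefill.
global_config.layer_sync_threshold = 4096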