sglang 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- sglang/__init__.py +55 -2
- sglang/api.py +3 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +1 -0
- sglang/lang/chat_template.py +74 -0
- sglang/lang/interpreter.py +40 -16
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +2 -1
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/router/infer_batch.py +70 -33
- sglang/srt/managers/router/manager.py +7 -2
- sglang/srt/managers/router/model_rpc.py +116 -73
- sglang/srt/managers/router/model_runner.py +111 -167
- sglang/srt/managers/router/radix_cache.py +46 -38
- sglang/srt/managers/tokenizer_manager.py +56 -11
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +7 -0
- sglang/srt/models/commandr.py +376 -0
- sglang/srt/models/dbrx.py +413 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +22 -20
- sglang/srt/models/llama2.py +23 -21
- sglang/srt/models/llava.py +12 -10
- sglang/srt/models/mixtral.py +27 -25
- sglang/srt/models/qwen.py +23 -21
- sglang/srt/models/qwen2.py +23 -21
- sglang/srt/models/stablelm.py +20 -21
- sglang/srt/models/yivl.py +6 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +68 -447
- sglang/srt/server_args.py +76 -49
- sglang/srt/utils.py +88 -32
- sglang/srt/weight_utils.py +402 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/METADATA +12 -14
- sglang-0.1.15.dist-info/RECORD +69 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -1,3 +1,5 @@
+"""The arguments of the server."""
+
 import argparse
 import dataclasses
 from typing import List, Optional, Union
@@ -5,34 +7,47 @@ from typing import List, Optional, Union
 
 @dataclasses.dataclass
 class ServerArgs:
+    # Model and tokenizer
     model_path: str
     tokenizer_path: Optional[str] = None
-    host: str = "127.0.0.1"
-    port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None
     load_format: str = "auto"
     tokenizer_mode: str = "auto"
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
+    context_length: Optional[int] = None
+
+    # Port
+    host: str = "127.0.0.1"
+    port: int = 30000
+    additional_ports: Optional[Union[List[int], int]] = None
+
+    # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_prefill_num_token: Optional[int] = None
-    context_length: Optional[int] = None
-    tp_size: int = 1
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
-
-
+
+    # Other runtime options
+    tp_size: int = 1
     stream_interval: int = 8
+    random_seed: int = 42
+
+    # Logging
+    log_level: str = "info"
+    log_requests: bool = False
     disable_log_stats: bool = False
     log_stats_interval: int = 10
-
+    show_time_cost: bool = False
 
-    #
-
+    # Other
+    api_key: str = ""
+
+    # Optimization/debug options
     enable_flashinfer: bool = False
+    attention_reduce_in_fp32: bool = False
+    disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
-    api_key: str = ""
 
     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -65,15 +80,16 @@ class ServerArgs:
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument("--host", type=str, default=ServerArgs.host
-
-
+        parser.add_argument("--host", type=str, default=ServerArgs.host,
+                            help="The host of the server.")
+        parser.add_argument("--port", type=int, default=ServerArgs.port,
+                            help="The port of the server.")
         parser.add_argument(
             "--additional-ports",
             type=int,
             nargs="*",
             default=[],
-            help="Additional ports specified for
+            help="Additional ports specified for the server.",
         )
         parser.add_argument(
             "--load-format",
@@ -111,6 +127,12 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -123,18 +145,6 @@ class ServerArgs:
             default=ServerArgs.max_prefill_num_token,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
-        )
-        parser.add_argument(
-            "--tp-size",
-            type=int,
-            default=ServerArgs.tp_size,
-            help="Tensor parallelism degree.",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
@@ -148,15 +158,10 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
-            "--
+            "--tp-size",
             type=int,
-            default=ServerArgs.
-            help="
-        )
-        parser.add_argument(
-            "--attention-reduce-in-fp32",
-            action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+            default=ServerArgs.tp_size,
+            help="Tensor parallelism size.",
         )
         parser.add_argument(
             "--stream-interval",
@@ -164,11 +169,22 @@ class ServerArgs:
             default=ServerArgs.stream_interval,
             help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
+        parser.add_argument(
+            "--random-seed",
+            type=int,
+            default=ServerArgs.random_seed,
+            help="Random seed.",
+        )
         parser.add_argument(
             "--log-level",
             type=str,
             default=ServerArgs.log_level,
-            help="
+            help="Logging level",
+        )
+        parser.add_argument(
+            "--log-requests",
+            action="store_true",
+            help="Log all requests",
         )
         parser.add_argument(
             "--disable-log-stats",
@@ -181,17 +197,34 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
-        # optional modes
         parser.add_argument(
-            "--
+            "--show-time-cost",
             action="store_true",
-            help="
+            help="Show time cost of custom marks",
+        )
+        parser.add_argument(
+            "--api-key",
+            type=str,
+            default=ServerArgs.api_key,
+            help="Set API key of the server",
         )
+
+        # Optimization/debug options
         parser.add_argument(
            "--enable-flashinfer",
            action="store_true",
            help="Enable flashinfer inference kernels",
        )
+        parser.add_argument(
+            "--attention-reduce-in-fp32",
+            action="store_true",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+        )
+        parser.add_argument(
+            "--disable-radix-cache",
+            action="store_true",
+            help="Disable RadixAttention",
+        )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
@@ -202,12 +235,6 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
-        parser.add_argument(
-            "--api-key",
-            type=str,
-            default=ServerArgs.api_key,
-            help="Set API Key",
-        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -217,13 +244,13 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
-    def
+    def print_mode_args(self):
         return (
-            f"disable_radix_cache={self.disable_radix_cache}, "
             f"enable_flashinfer={self.enable_flashinfer}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
+            f"disable_radix_cache={self.disable_radix_cache}, "
             f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
             f"disable_disk_cache={self.disable_disk_cache}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
         )
 
 
@@ -233,4 +260,4 @@ class PortArgs:
     router_port: int
     detokenizer_port: int
     nccl_port: int
-    model_rpc_ports: List[int]
+    model_rpc_ports: List[int]
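The reorganized ServerArgs dataclass groups the options into model/tokenizer, port, memory/scheduling, runtime, logging, and optimization sections, and adds context_length, random_seed, log_requests, show_time_cost, api_key, attention_reduce_in_fp32, and disable_radix_cache as first-class fields. A minimal sketch of constructing it directly in Python; only fields and helpers visible in the diff above are used, and the model path and values are placeholders:

from sglang.srt.server_args import ServerArgs

# Placeholder values; any HF model path works here.
args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    context_length=4096,      # new field: overrides the length from config.json
    random_seed=42,           # new runtime option
    log_requests=True,        # new logging flag
    api_key="my-secret-key",  # new "Other" option, see --api-key
)

print(args.url())              # "http://127.0.0.1:30000" with the default host/port
print(args.print_mode_args())  # one-line summary of the optimization/debug flags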
sglang/srt/utils.py
CHANGED
@@ -1,3 +1,5 @@
+"""Common utilities."""
+
 import base64
 import os
 import random
@@ -5,54 +7,68 @@ import socket
 import sys
 import time
 import traceback
+from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional
 
 import numpy as np
+import pydantic
 import requests
 import torch
-
+from fastapi.responses import JSONResponse
+from packaging import version as pkg_version
+from pydantic import BaseModel
+from starlette.middleware.base import BaseHTTPMiddleware
 
-
+show_time_cost = False
+time_infos = {}
 
 
-def
-
-
-        if dist.get_rank() in [0, 1] and is_show_cost_time:
-            torch.cuda.synchronize()
-            start_time = time.time()
-            ans = func(*args, **kwargs)
-            torch.cuda.synchronize()
-            print(func_name, "cost time:", (time.time() - start_time) * 1000)
-            return ans
-        else:
-            torch.cuda.synchronize()
-            ans = func(*args, **kwargs)
-            torch.cuda.synchronize()
-            return ans
+def enable_show_time_cost():
+    global show_time_cost
+    show_time_cost = True
 
-    return time_func
 
-
+class TimeInfo:
+    def __init__(self, name, interval=0.1, color=0, indent=0):
+        self.name = name
+        self.interval = interval
+        self.color = color
+        self.indent = indent
 
+        self.acc_time = 0
+        self.last_acc_time = 0
+
+    def check(self):
+        if self.acc_time - self.last_acc_time > self.interval:
+            self.last_acc_time = self.acc_time
+            return True
+        return False
 
-
+    def pretty_print(self):
+        print(f"\x1b[{self.color}m", end="")
+        print("-" * self.indent * 2, end="")
+        print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")
 
 
-def mark_start(
+def mark_start(name, interval=0.1, color=0, indent=0):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-
-
-
+    if time_infos.get(name, None) is None:
+        time_infos[name] = TimeInfo(name, interval, color, indent)
+    time_infos[name].acc_time -= time.time()
 
 
-def mark_end(
+def mark_end(name):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-
-
-
-    print(f"cost {key}:", cost_time)
+    time_infos[name].acc_time += time.time()
+    if time_infos[name].check():
+        time_infos[name].pretty_print()
 
 
 def calculate_time(show=False, min_cost_ms=0.0):
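The decorator-style timing helper from 0.1.14 is replaced by a global show_time_cost switch plus explicit mark_start/mark_end calls that accumulate wall-clock time per named region and pretty-print it once the region has gained more than `interval` seconds since the last report. A rough usage sketch, assuming a CUDA device is available (both calls synchronize the GPU); the region name "matmul" is arbitrary:

import torch
from sglang.srt.utils import enable_show_time_cost, mark_start, mark_end

enable_show_time_cost()  # without this, mark_start/mark_end return immediately

x = torch.randn(4096, 4096, device="cuda")
for _ in range(100):
    mark_start("matmul", interval=0.1, color=32, indent=1)
    y = x @ x  # the GPU work attributed to the "matmul" region
    mark_end("matmul")
    # mark_end pretty-prints the accumulated time whenever it has grown
    # by more than `interval` seconds since the last print.

The --show-time-cost server flag added above presumably flips the same switch at startup.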
@@ -110,7 +126,7 @@ def check_port(port):
     return False
 
 
-def
+def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
@@ -149,8 +165,6 @@ def get_exception_traceback():
 
 
 def get_int_token_logit_bias(tokenizer, vocab_size):
-    from transformers import LlamaTokenizer, LlamaTokenizerFast
-
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -259,3 +273,45 @@ def load_image(image_file):
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
     return image
+
+
+def assert_pkg_version(pkg: str, min_version: str):
+    try:
+        installed_version = version(pkg)
+        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
+            raise Exception(
+                f"{pkg} is installed with version {installed_version} which "
+                f"is less than the minimum required version {min_version}"
+            )
+    except PackageNotFoundError:
+        raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
+
+
+API_KEY_HEADER_NAME = "X-API-Key"
+
+
+class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, api_key: str):
+        super().__init__(app)
+        self.api_key = api_key
+
+    async def dispatch(self, request, call_next):
+        # extract API key from the request headers
+        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+        if not api_key_header or api_key_header != self.api_key:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "Invalid API Key"},
+            )
+        response = await call_next(request)
+        return response
+
+
+# FIXME: Remove this once we drop support for pydantic 1.x
+IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
+
+
+def jsonify_pydantic_model(obj: BaseModel):
+    if IS_PYDANTIC_1:
+        return obj.json(ensure_ascii=False)
+    return obj.model_dump_json()