sglang-0.1.15-py3-none-any.whl → sglang-0.1.17-py3-none-any.whl
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/server.py
CHANGED
@@ -9,7 +9,8 @@ import os
import sys
import threading
import time
-from
+from http import HTTPStatus
+from typing import Optional

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -20,26 +21,30 @@ import requests
import uvicorn
import uvloop
from fastapi import FastAPI, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse

from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.
+from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
+from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api_adapter import (
-
-
+    load_chat_template_for_openai_api,
+    v1_chat_completions,
+    v1_completions,
+)
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
from sglang.srt.utils import (
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware,
    allocate_init_ports,
    assert_pkg_version,
    enable_show_time_cost,
-    get_exception_traceback,
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware
)
+from sglang.utils import get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -69,7 +74,7 @@ async def get_server_args():

@app.get("/flush_cache")
async def flush_cache():
-
+    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -77,21 +82,32 @@ async def flush_cache():
    )


-
-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
    if obj.stream:

        async def stream_results():
-
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

-        return StreamingResponse(stream_results(), media_type="text/event-stream"
+        return StreamingResponse(stream_results(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(obj))
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse(
+                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+            )
+

-
-
+app.post("/generate")(generate_request)
+app.put("/generate")(generate_request)


@app.post("/v1/completions")
@@ -104,7 +120,7 @@ async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)


-def launch_server(server_args: ServerArgs, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
    global tokenizer_manager

    logging.basicConfig(
@@ -126,28 +142,42 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):

    # Allocate ports
    server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port,
+        server_args.port,
+        server_args.additional_ports,
+        server_args.tp_size,
+        server_args.dp_size,
    )
+
+    # Init local models port args
+    ports = server_args.additional_ports
+    tp = server_args.tp_size
+    model_port_args = []
+    for i in range(server_args.dp_size):
+        model_port_args.append(
+            ModelPortArgs(
+                nccl_port=ports[3 + i * (tp + 1)],
+                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+            )
+        )
    port_args = PortArgs(
-        tokenizer_port=
-        router_port=
-        detokenizer_port=
-
-        model_rpc_ports=server_args.additional_ports[4:],
+        tokenizer_port=ports[0],
+        router_port=ports[1],
+        detokenizer_port=ports[2],
+        model_port_args=model_port_args,
    )

    # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

+    if server_args.dp_size == 1:
+        start_process = start_controller_process_single
+    else:
+        start_process = start_controller_process_multi
    proc_router = mp.Process(
-        target=
-        args=(
-            server_args,
-            port_args,
-            pipe_router_writer,
-        ),
+        target=start_process,
+        args=(server_args, port_args, pipe_router_writer, model_overide_args),
    )
    proc_router.start()
    proc_detoken = mp.Process(
@@ -167,14 +197,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
    if router_init_state != "init ok" or detoken_init_state != "init ok":
        proc_router.kill()
        proc_detoken.kill()
-        print(
-
+        print(
+            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+        )
+        print(
+            f"Initialization failed. detoken_init_state: {detoken_init_state}",
+            flush=True,
+        )
        sys.exit(1)
    assert proc_router.is_alive() and proc_detoken.is_alive()

    if server_args.api_key and server_args.api_key != "":
        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

+    # Send a warmup request
    def _wait_and_warmup():
        headers = {}
        url = server_args.url()
@@ -192,20 +228,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):

        # Send a warmup request
        try:
-
-
-
-
-
-            "
-
+            for _ in range(server_args.dp_size):
+                res = requests.post(
+                    url + "/generate",
+                    json={
+                        "text": "The capital city of France is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 16,
+                        },
                    },
-
-
-
-
-
-        except Exception as e:
+                    headers=headers,
+                    timeout=600,
+                )
+                assert res.status_code == 200
+        except Exception:
            if pipe_finish_writer is not None:
                pipe_finish_writer.send(get_exception_traceback())
            print(f"Initialization failed. warmup error: {e}")
@@ -216,6 +253,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):

    t = threading.Thread(target=_wait_and_warmup)
    t.start()
+
+    # Listen for requests
    try:
        uvicorn.run(
            app,
@@ -232,16 +271,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
class Runtime:
    def __init__(
        self,
-
+        log_level: str = "error",
+        model_overide_args: Optional[dict] = None,
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
-        self.server_args = ServerArgs(*args, log_level=
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports
        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
-            self.server_args.port,
+            self.server_args.port,
+            self.server_args.additional_ports,
+            self.server_args.tp_size,
+            self.server_args.dp_size,
+        )

        self.url = self.server_args.url()
        self.generate_url = (
@@ -250,7 +294,10 @@ class Runtime:

        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-        proc = mp.Process(
+        proc = mp.Process(
+            target=launch_server,
+            args=(self.server_args, pipe_writer, model_overide_args),
+        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid
@@ -262,7 +309,9 @@ class Runtime:

        if init_state != "init ok":
            self.shutdown()
-            raise RuntimeError(
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )

        self.endpoint = RuntimeEndpoint(self.url)

@@ -314,4 +363,4 @@ class Runtime:
            pos += len(cur)

    def __del__(self):
-        self.shutdown()
+        self.shutdown()
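The port layout built in `launch_server` above reserves the first three additional ports for the tokenizer, router, and detokenizer, then assigns one NCCL port plus `tp_size` model ports to each data-parallel worker. A minimal sketch of that indexing, assuming an already-allocated list of free ports (the helper name and the concrete port numbers below are illustrative, not part of the package):

    # Sketch of the slicing used by launch_server above; split_ports is a
    # hypothetical helper and the port values are made up.
    def split_ports(ports, tp_size, dp_size):
        model_port_args = []
        for i in range(dp_size):
            base = 3 + i * (tp_size + 1)
            nccl_port = ports[base]                                 # one NCCL init port per DP worker
            model_tp_ports = ports[base + 1 : base + 1 + tp_size]   # one port per TP rank
            model_port_args.append((nccl_port, model_tp_ports))
        return model_port_args

    # dp_size=2 and tp_size=2 consume 3 + 2 * (2 + 1) = 9 additional ports.
    print(split_ports(list(range(30000, 30009)), tp_size=2, dp_size=2))
    # [(30003, [30004, 30005]), (30006, [30007, 30008])]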
sglang/srt/server_args.py
CHANGED
@@ -2,6 +2,7 @@

import argparse
import dataclasses
+import random
from typing import List, Optional, Union


@@ -15,6 +16,7 @@ class ServerArgs:
    chat_template: Optional[str] = None
    trust_remote_code: bool = True
    context_length: Optional[int] = None
+    quantization: Optional[str] = None

    # Port
    host: str = "127.0.0.1"
@@ -23,14 +25,15 @@ class ServerArgs:

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
-
+    max_prefill_tokens: Optional[int] = None
+    max_running_requests: Optional[int] = None
    schedule_heuristic: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 8
-    random_seed: int =
+    random_seed: Optional[int] = None

    # Logging
    log_level: str = "info"
@@ -42,6 +45,10 @@ class ServerArgs:
    # Other
    api_key: str = ""

+    # Data parallelism
+    dp_size: int = 1
+    load_balance_method: str = "round_robin"
+
    # Optimization/debug options
    enable_flashinfer: bool = False
    attention_reduce_in_fp32: bool = False
@@ -66,6 +73,9 @@ class ServerArgs:
        elif self.additional_ports is None:
            self.additional_ports = []

+        if self.random_seed is None:
+            self.random_seed = random.randint(0, 1 << 30)
+
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument(
@@ -80,10 +90,12 @@ class ServerArgs:
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
-        parser.add_argument(
-
-
-
+        parser.add_argument(
+            "--host", type=str, default=ServerArgs.host, help="The host of the server."
+        )
+        parser.add_argument(
+            "--port", type=int, default=ServerArgs.port, help="The port of the server."
+        )
        parser.add_argument(
            "--additional-ports",
            type=int,
@@ -133,6 +145,12 @@ class ServerArgs:
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
+        parser.add_argument(
+            "--quantization",
+            type=str,
+            default=ServerArgs.quantization,
+            help="The quantization method.",
+        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
@@ -140,16 +158,23 @@ class ServerArgs:
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
-            "--max-prefill-
+            "--max-prefill-tokens",
            type=int,
-            default=ServerArgs.
+            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
+        parser.add_argument(
+            "--max-running-requests",
+            type=int,
+            default=ServerArgs.max_running_requests,
+            help="The maximum number of running requests.",
+        )
        parser.add_argument(
            "--schedule-heuristic",
            type=str,
            default=ServerArgs.schedule_heuristic,
-
+            choices=["lpm", "random", "fcfs", "dfs-weight"],
+            help="Scheduling Heuristic.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
@@ -209,6 +234,24 @@ class ServerArgs:
            help="Set API key of the server",
        )

+        # Data parallelism
+        parser.add_argument(
+            "--dp-size",
+            type=int,
+            default=ServerArgs.dp_size,
+            help="Data parallelism size.",
+        )
+        parser.add_argument(
+            "--load-balance-method",
+            type=str,
+            default=ServerArgs.load_balance_method,
+            help="Load balancing strategy for data parallelism.",
+            choices=[
+                "round_robin",
+                "shortest_queue",
+            ],
+        )
+
        # Optimization/debug options
        parser.add_argument(
            "--enable-flashinfer",
@@ -254,10 +297,15 @@ class ServerArgs:
        )


+@dataclasses.dataclass
+class ModelPortArgs:
+    nccl_port: int
+    model_tp_ports: List[int]
+
+
@dataclasses.dataclass
class PortArgs:
    tokenizer_port: int
    router_port: int
    detokenizer_port: int
-
-    model_rpc_ports: List[int]
+    model_port_args: List[ModelPortArgs]