sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +3 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +8 -1
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +17 -2
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +1 -1
- sglang/srt/hf_transformers_utils.py +75 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +15 -11
- sglang/srt/managers/router/infer_batch.py +103 -59
- sglang/srt/managers/router/manager.py +1 -1
- sglang/srt/managers/router/model_rpc.py +175 -122
- sglang/srt/managers/router/model_runner.py +91 -104
- sglang/srt/managers/router/radix_cache.py +7 -1
- sglang/srt/managers/router/scheduler.py +6 -6
- sglang/srt/managers/tokenizer_manager.py +152 -89
- sglang/srt/model_config.py +4 -5
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +8 -15
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +19 -15
- sglang/srt/models/llava.py +84 -20
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +248 -118
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +77 -42
- sglang/srt/server_args.py +51 -6
- sglang/srt/utils.py +124 -66
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +22 -4
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -2,6 +2,7 @@

 import argparse
 import dataclasses
+import random
 from typing import List, Optional, Union


@@ -15,6 +16,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
     context_length: Optional[int] = None
+    quantization: Optional[str] = None

     # Port
     host: str = "127.0.0.1"
@@ -23,14 +25,15 @@

     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
-
+    max_prefill_tokens: Optional[int] = None
+    max_running_requests: Optional[int] = None
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0

     # Other runtime options
     tp_size: int = 1
     stream_interval: int = 8
-    random_seed: int =
+    random_seed: Optional[int] = None

     # Logging
     log_level: str = "info"
@@ -42,6 +45,10 @@ class ServerArgs:
     # Other
     api_key: str = ""

+    # Data parallelism
+    dp_size: int = 1
+    load_balance_method: str = "round_robin"
+
     # Optimization/debug options
     enable_flashinfer: bool = False
     attention_reduce_in_fp32: bool = False
@@ -66,6 +73,9 @@ class ServerArgs:
         elif self.additional_ports is None:
             self.additional_ports = []

+        if self.random_seed is None:
+            self.random_seed = random.randint(0, 1 << 30)
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
@@ -135,6 +145,12 @@ class ServerArgs:
             default=ServerArgs.context_length,
             help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
         )
+        parser.add_argument(
+            "--quantization",
+            type=str,
+            default=ServerArgs.quantization,
+            help="The quantization method.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -142,11 +158,17 @@ class ServerArgs:
             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
         parser.add_argument(
-            "--max-prefill-
+            "--max-prefill-tokens",
             type=int,
-            default=ServerArgs.
+            default=ServerArgs.max_prefill_tokens,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
+        parser.add_argument(
+            "--max-running-requests",
+            type=int,
+            default=ServerArgs.max_running_requests,
+            help="The maximum number of running requests.",
+        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
@@ -212,6 +234,24 @@ class ServerArgs:
             help="Set API key of the server",
         )

+        # Data parallelism
+        parser.add_argument(
+            "--dp-size",
+            type=int,
+            default=ServerArgs.dp_size,
+            help="Data parallelism size.",
+        )
+        parser.add_argument(
+            "--load-balance-method",
+            type=str,
+            default=ServerArgs.load_balance_method,
+            help="Load balancing strategy for data parallelism.",
+            choices=[
+                "round_robin",
+                "shortest_queue",
+            ],
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
@@ -257,10 +297,15 @@ class ServerArgs:
         )


+@dataclasses.dataclass
+class ModelPortArgs:
+    nccl_port: int
+    model_tp_ports: List[int]
+
+
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
     router_port: int
     detokenizer_port: int
-
-    model_rpc_ports: List[int]
+    model_port_args: List[ModelPortArgs]
sglang/srt/utils.py
CHANGED
@@ -1,6 +1,8 @@
 """Common utilities."""

 import base64
+import multiprocessing
+import logging
 import os
 import random
 import socket
@@ -10,15 +12,19 @@ from io import BytesIO
 from typing import List, Optional

 import numpy as np
-import
+import psutil
 import requests
+import rpyc
 import torch
+import triton
+from rpyc.utils.server import ThreadedServer
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from pydantic import BaseModel
 from starlette.middleware.base import BaseHTTPMiddleware

-
+
+logger = logging.getLogger(__name__)
+

 show_time_cost = False
 time_infos = {}
@@ -90,7 +96,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def get_available_gpu_memory(gpu_id, distributed=
+def get_available_gpu_memory(gpu_id, distributed=False):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -104,6 +110,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
             "which may cause useless memory allocation for torch CUDA context.",
         )

+    torch.cuda.empty_cache()
     free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

     if distributed:
@@ -117,38 +124,21 @@


 def set_random_seed(seed: int) -> None:
+    """Set the random seed for all libraries."""
     random.seed(seed)
-
+    np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)


-def
-
-    for port in range(10000, 65536):
-        if port in used_list:
-            continue
-
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-            try:
-                s.bind(("", port))
-                s.listen(1)  # Attempt to listen on the port
-                port_list.append(port)
-            except socket.error:
-                pass  # If any error occurs, this port is not usable
-
-        if len(port_list) == num:
-            return port_list
-    return None
-
-
-def check_port(port):
+def is_port_available(port):
+    """Return whether a port is available."""
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         try:
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             s.bind(("", port))
+            s.listen(1)
             return True
         except socket.error:
             return False
@@ -158,35 +148,34 @@ def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
+    dp_size: int = 1,
 ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        num=4 + tp_size - num_specified_ports, used_list=can_use_ports + [port]
+    """Allocate ports for all connections."""
+    if additional_ports:
+        ret_ports = [port] + additional_ports
+    else:
+        ret_ports = [port]
+
+    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
+    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
+
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
+    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    while len(ret_ports) < num_ports_needed:
+        if cur_port not in ret_ports and is_port_available(cur_port):
+            ret_ports.append(cur_port)
+        cur_port += 1
+
+    if port is not None and ret_ports[0] != port:
+        logger.warn(
+            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
         )
-        can_use_ports.extend(addtional_can_use_ports)

-
-    return port, additional_ports
+    return ret_ports[0], ret_ports[1:num_ports_needed]


 def get_int_token_logit_bias(tokenizer, vocab_size):
+    """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -200,14 +189,11 @@ def get_int_token_logit_bias(tokenizer, vocab_size):

 def wrap_kernel_launcher(kernel):
     """A faster launcher for triton kernels."""
-
-
-    if dist.is_initialized():
-        rank = dist.get_rank()
-    else:
-        rank = 0
+    if int(triton.__version__.split(".")[0]) >= 3:
+        return None

-
+    gpu_id = torch.cuda.current_device()
+    kernels = kernel.cache[gpu_id].values()
     kernel = next(iter(kernels))

     # Different trition versions use different low-level names
@@ -275,7 +261,9 @@ def is_multimodal_model(model):

     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return
+        return (
+            "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+        )

     raise ValueError("unrecognized type")

@@ -382,6 +370,64 @@ def load_image(image_file):
     return image, image_size


+def init_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def connect_to_rpyc_service(port, host="localhost"):
+    time.sleep(1)
+
+    repeat_count = 0
+    while repeat_count < 20:
+        try:
+            con = rpyc.connect(
+                host,
+                port,
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 3600
+                },
+            )
+            break
+        except ConnectionRefusedError:
+            time.sleep(1)
+        repeat_count += 1
+    if repeat_count == 20:
+        raise RuntimeError("init rpc env error!")
+
+    return con.root
+
+
+def start_rpyc_process(service: rpyc.Service, port: int):
+    # Return the proxy and the process
+    proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
+    proc.start()
+    proxy = connect_to_rpyc_service(port)
+    assert proc.is_alive()
+    return proxy, proc
+
+
+def suppress_other_loggers():
+    from vllm.logger import logger as vllm_default_logger
+
+    vllm_default_logger.setLevel(logging.WARN)
+    logging.getLogger("vllm.config").setLevel(logging.ERROR)
+    logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN)
+    logging.getLogger("vllm.selector").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+
+
 def assert_pkg_version(pkg: str, min_version: str):
     try:
         installed_version = version(pkg)
@@ -396,6 +442,27 @@ def assert_pkg_version(pkg: str, min_version: str):
         )


+def kill_parent_process():
+    """Kill the parent process and all children of the parent process."""
+    current_process = psutil.Process()
+    parent_process = current_process.parent()
+    children = current_process.children(recursive=True)
+    for child in children:
+        if child.pid != current_process.pid:
+            os.kill(child.pid, 9)
+    os.kill(parent_process.pid, 9)
+
+
+def monkey_patch_vllm_p2p_access_check():
+    """
+    Monkey patch the slow p2p access check in vllm.
+    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+    """
+    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
+
+
 API_KEY_HEADER_NAME = "X-API-Key"


@@ -415,12 +482,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         response = await call_next(request)
         return response

-
-# FIXME: Remove this once we drop support for pydantic 1.x
-IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
-def jsonify_pydantic_model(obj: BaseModel):
-    if IS_PYDANTIC_1:
-        return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
sglang/test/test_programs.py
CHANGED
@@ -304,6 +304,7 @@ def test_image_qa():
         temperature=0,
         max_new_tokens=64,
     )
+
     assert (
         "taxi" in state.messages()[-1]["content"]
         or "car" in state.messages()[-1]["content"]
@@ -349,3 +350,46 @@ def test_regex():
     state = regex_gen.run()
     answer = state["answer"]
     assert re.match(regex, answer)
+
+
+def test_completion_speculative():
+    @sgl.function(num_api_spec_tokens=64)
+    def gen_character_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+
+    @sgl.function
+    def gen_character_no_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    token_usage = sgl.global_config.default_backend.token_usage
+
+    token_usage.reset()
+    gen_character_spec().sync()
+    usage_with_spec = token_usage.prompt_tokens
+
+    token_usage.reset()
+    gen_character_no_spec().sync()
+    usage_with_no_spec = token_usage.prompt_tokens
+
+    assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
+
+
+def test_chat_completion_speculative():
+    @sgl.function(num_api_spec_tokens=256)
+    def gen_character_spec(s):
+        s += sgl.system("You are a helpful assistant.")
+        s += sgl.user("Construct a character within the following format:")
+        s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
+        s += sgl.user("Please generate new Name, Birthday and Job.\n")
+        s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+
+    gen_character_spec().sync()
sglang/test/test_utils.py
CHANGED
@@ -9,7 +9,7 @@ import requests
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
-from sglang.
+from sglang.utils import get_exception_traceback


 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
@@ -88,6 +88,33 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
     return pred


+def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
+    import grpc
+    from ginfer import sampler_pb2, sampler_pb2_grpc
+
+    sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
+    sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
+
+    if stop is None:
+        stop_strings = None
+    else:
+        stop_strings = [stop]
+
+    sample_request = sampler_pb2.SampleTextRequest(
+        prompt=prompt,
+        settings=sampler_pb2.SampleSettings(
+            max_len=max_tokens,
+            rng_seed=0,
+            temperature=max(temperature, 1e-7),
+            nucleus_p=1,
+            stop_strings=stop_strings,
+        ),
+    )
+    stream = sampler.SampleText(sample_request)
+    response = "".join([x.text for x in stream])
+    return response
+
+
 def call_generate_guidance(
     prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
 ):
@@ -228,6 +255,7 @@ def add_common_other_args_and_parse(parser):
         "vllm",
         "outlines",
         "lightllm",
+        "ginfer",
         "guidance",
         "lmql",
         "srt-raw",
@@ -248,6 +276,7 @@ def add_common_other_args_and_parse(parser):
         "lightllm": 22000,
         "lmql": 23000,
         "srt-raw": 30000,
+        "ginfer": 9988,
     }
     args.port = default_port.get(args.backend, None)
     return args
@@ -283,6 +312,8 @@ def _get_call_generate(args):
         return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "srt-raw":
         return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "ginfer":
+        return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
     elif args.backend == "outlines":
         return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "guidance":
sglang/utils.py
CHANGED
@@ -2,7 +2,8 @@

 import base64
 import json
-import
+import logging
+import signal
 import sys
 import threading
 import traceback
@@ -15,6 +16,9 @@ import numpy as np
 import requests


+logger = logging.getLogger(__name__)
+
+
 def get_exception_traceback():
     etype, value, tb = sys.exc_info()
     err_str = "".join(traceback.format_exception(etype, value, tb))
@@ -93,8 +97,12 @@ def http_request(
         data = None
     else:
         data = bytes(dumps(json), encoding="utf-8")
-
-
+
+    try:
+        resp = urllib.request.urlopen(req, data=data, cafile=verify)
+        return HttpResponse(resp)
+    except urllib.error.HTTPError as e:
+        return HttpResponse(e)


 def encode_image_base64(image_path):
@@ -137,7 +145,8 @@ def encode_frame(frame):


 def encode_video_base64(video_path, num_frames=16):
-    import cv2
+    import cv2  # pip install opencv-python-headless
+
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         raise IOError(f"Could not open video file:{video_path}")
@@ -242,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
         raise RuntimeError()

     return ret_value[0]
+
+
+def graceful_registry(sub_module_name):
+    def graceful_shutdown(signum, frame):
+        logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
+        if signum == signal.SIGTERM:
+            logger.info(f"{sub_module_name} recive sigterm")
+
+    signal.signal(signal.SIGTERM, graceful_shutdown)
{sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.1.16
+Version: 0.1.17
 Summary: A structured generation langauge for LLMs.
 License: Apache License
 Version 2.0, January 2004
@@ -217,9 +217,12 @@ Provides-Extra: all
 Requires-Dist: sglang[srt] ; extra == 'all'
 Requires-Dist: sglang[openai] ; extra == 'all'
 Requires-Dist: sglang[anthropic] ; extra == 'all'
+Requires-Dist: sglang[litellm] ; extra == 'all'
 Provides-Extra: anthropic
 Requires-Dist: anthropic >=0.20.0 ; extra == 'anthropic'
 Requires-Dist: numpy ; extra == 'anthropic'
+Provides-Extra: litellm
+Requires-Dist: litellm >=1.0.0 ; extra == 'litellm'
 Provides-Extra: openai
 Requires-Dist: openai >=1.0 ; extra == 'openai'
 Requires-Dist: numpy ; extra == 'openai'
@@ -233,7 +236,7 @@ Requires-Dist: torch ; extra == 'srt'
 Requires-Dist: uvloop ; extra == 'srt'
 Requires-Dist: uvicorn ; extra == 'srt'
 Requires-Dist: zmq ; extra == 'srt'
-Requires-Dist: vllm
+Requires-Dist: vllm ==0.4.3 ; extra == 'srt'
 Requires-Dist: interegular ; extra == 'srt'
 Requires-Dist: pydantic ; extra == 'srt'
 Requires-Dist: pillow ; extra == 'srt'
@@ -253,9 +256,9 @@ Requires-Dist: outlines >=0.0.34 ; extra == 'srt'
 SGLang is a structured generation language designed for large language models (LLMs).
 It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.

-The core features
+The core features include:
 - **A Flexible Front-End Language**: This allows for easy programming of LLM applications with multiple chained generation calls, advanced prompting techniques, control flow, multiple modalities, parallelism, and external interaction.
-- **A High-Performance Runtime with RadixAttention**: This feature significantly accelerates the execution of complex LLM programs by
+- **A High-Performance Runtime with RadixAttention**: This feature significantly accelerates the execution of complex LLM programs by automatically reusing the KV cache across multiple calls. It can also be used as a standalone serving engine with all common techniques implemented, such as continuous batching and tensor parallelism.

 ## News
 - [2024/02] 🔥 SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
@@ -288,12 +291,8 @@ pip install -e "python[all]"
 ```

 ### Notes
-- If you are using older GPUs (NVIDIA V100, T4), please pick the correct triton compiler version to avoid some known bugs.
-  - For NVIDIA T4, please use `pip install "triton>=2.2.0"`.
-  - For NVIDIA V100, please install the [nightly](https://triton-lang.org/main/getting-started/installation.html) version.
 - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`

-
 ## Quick Start
 The example below shows how to use sglang to answer a mulit-turn question.

@@ -603,11 +602,16 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
 ```
+- Add `--dp 2` to enable data parallelism. It can also be used together with tp. Data parallelism is better for throughput if there is enough memory.
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --dp 2 --tp 2
+```
 - If you see out-of-memory errors during serving, please try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
 ```
--
+- See [flashinfer.md](docs/flashinfer.md) on accelerating inference using highly optimized CUDA kernels.
+- See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.

 ### Supported Models
 - Llama
@@ -621,6 +625,8 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
 - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
 - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 3000`
+- LLaVA-NeXT-Video
+  - see [srt_example_llava_v.sh](examples/usage/llava_video/srt_example_llava_v.sh)
 - Yi-VL
   - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py).
 - StableLM