sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
@@ -1,24 +1,31 @@
 """Common utilities."""
 
 import base64
+import fcntl
+import logging
+import multiprocessing
 import os
 import random
 import socket
+import struct
 import time
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional
 
 import numpy as np
-import
+import psutil
 import requests
+import rpyc
 import torch
+import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from
+from rpyc.utils.server import ThreadedServer
 from starlette.middleware.base import BaseHTTPMiddleware
 
-
+logger = logging.getLogger(__name__)
+
 
 show_time_cost = False
 time_infos = {}
@@ -90,7 +97,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
-def get_available_gpu_memory(gpu_id, distributed=True):
+def get_available_gpu_memory(gpu_id, distributed=False):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -104,6 +111,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
             "which may cause useless memory allocation for torch CUDA context.",
         )
 
+    torch.cuda.empty_cache()
     free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
 
     if distributed:
@@ -117,38 +125,21 @@ def get_available_gpu_memory(gpu_id, distributed=True):
 
 
 def set_random_seed(seed: int) -> None:
+    """Set the random seed for all libraries."""
     random.seed(seed)
-
+    np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
 
 
-def
-
-    for port in range(10000, 65536):
-        if port in used_list:
-            continue
-
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-            try:
-                s.bind(("", port))
-                s.listen(1)  # Attempt to listen on the port
-                port_list.append(port)
-            except socket.error:
-                pass  # If any error occurs, this port is not usable
-
-        if len(port_list) == num:
-            return port_list
-    return None
-
-
-def check_port(port):
+def is_port_available(port):
+    """Return whether a port is available."""
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         try:
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             s.bind(("", port))
+            s.listen(1)
             return True
         except socket.error:
             return False
@@ -158,35 +149,34 @@ def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
+    dp_size: int = 1,
 ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        num=4 + tp_size - num_specified_ports, used_list=can_use_ports + [port]
+    """Allocate ports for all connections."""
+    if additional_ports:
+        ret_ports = [port] + additional_ports
+    else:
+        ret_ports = [port]
+
+    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
+    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
+
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
+    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    while len(ret_ports) < num_ports_needed:
+        if cur_port not in ret_ports and is_port_available(cur_port):
+            ret_ports.append(cur_port)
+        cur_port += 1
+
+    if port is not None and ret_ports[0] != port:
+        logger.warn(
+            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
         )
-        can_use_ports.extend(addtional_can_use_ports)
 
-
-    return port, additional_ports
+    return ret_ports[0], ret_ports[1:num_ports_needed]
 
 
 def get_int_token_logit_bias(tokenizer, vocab_size):
+    """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -200,14 +190,11 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
 
 def wrap_kernel_launcher(kernel):
     """A faster launcher for triton kernels."""
-
-
-    if dist.is_initialized():
-        rank = dist.get_rank()
-    else:
-        rank = 0
+    if int(triton.__version__.split(".")[0]) >= 3:
+        return None
 
-
+    gpu_id = torch.cuda.current_device()
+    kernels = kernel.cache[gpu_id].values()
     kernel = next(iter(kernels))
 
     # Different trition versions use different low-level names
@@ -275,7 +262,9 @@ def is_multimodal_model(model):
 
     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return
+        return (
+            "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+        )
 
     raise ValueError("unrecognized type")
 
@@ -382,20 +371,145 @@ def load_image(image_file):
     return image, image_size
 
 
-def
+def connect_rpyc_service(host, port):
+    repeat_count = 0
+    while repeat_count < 20:
+        try:
+            con = rpyc.connect(
+                host,
+                port,
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 3600,
+                },
+            )
+            break
+        except ConnectionRefusedError as e:
+            time.sleep(1)
+        repeat_count += 1
+    if repeat_count == 20:
+        raise RuntimeError(f"Connect rpyc error: {e}")
+
+    return con.root
+
+
+def start_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def start_rpyc_service_process(service: rpyc.Service, port: int):
+    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
+    proc.start()
+    return proc
+
+
+def suppress_other_loggers():
+    from vllm.logger import logger as vllm_default_logger
+
+    vllm_default_logger.setLevel(logging.WARN)
+    logging.getLogger("vllm.config").setLevel(logging.ERROR)
+    logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
+        logging.WARN
+    )
+    logging.getLogger("vllm.selector").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+
+
+def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
         installed_version = version(pkg)
         if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
             raise Exception(
-                f"{pkg} is installed with version {installed_version} which "
-                f"is less than the minimum required version {min_version}"
+                f"{pkg} is installed with version {installed_version}, which "
+                f"is less than the minimum required version {min_version}. " +
+                message
             )
     except PackageNotFoundError:
         raise Exception(
-            f"{pkg} with minimum required version {min_version} is not installed"
+            f"{pkg} with minimum required version {min_version} is not installed. " +
+            message
         )
 
 
+def kill_parent_process():
+    """Kill the parent process and all children of the parent process."""
+    current_process = psutil.Process()
+    parent_process = current_process.parent()
+    children = current_process.children(recursive=True)
+    for child in children:
+        if child.pid != current_process.pid:
+            os.kill(child.pid, 9)
+    os.kill(parent_process.pid, 9)
+
+
+def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+    """
+    Monkey patch the slow p2p access check in vllm.
+    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+    """
+
+    # TODO: need a better check than just dev str name match
+    # compat: skip RTX 40 series as they do not have P2P feature and even checking for them may cause errors
+    device_name = torch.cuda.get_device_name(gpu_id)
+    if "RTX 40" not in device_name:
+        import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+        setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
+
+
+def monkey_patch_vllm_dummy_weight_loader():
+    """
+    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
+    """
+
+    from vllm.model_executor.model_loader.loader import (
+        ModelConfig, DeviceConfig, LoRAConfig, VisionLanguageConfig,
+        ParallelConfig, SchedulerConfig, CacheConfig, nn,
+        set_default_torch_dtype, _initialize_model, initialize_dummy_weights,
+        DummyModelLoader
+    )
+
+    def load_model(self, *, model_config: ModelConfig,
+                   device_config: DeviceConfig,
+                   lora_config: Optional[LoRAConfig],
+                   vision_language_config: Optional[VisionLanguageConfig],
+                   parallel_config: ParallelConfig,
+                   scheduler_config: SchedulerConfig,
+                   cache_config: CacheConfig) -> nn.Module:
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config,
+                                          lora_config, vision_language_config,
+                                          cache_config)
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    quant_method.process_weights_after_loading(module)
+                # FIXME: Remove this after Mixtral is updated
+                # to use quant_method.
+                if hasattr(module, "process_weights_after_loading"):
+                    module.process_weights_after_loading()
+
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+        return model.eval()
+
+    setattr(DummyModelLoader, "load_model", load_model)
+
+
 API_KEY_HEADER_NAME = "X-API-Key"
 
 
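As an aside, a minimal sketch of how the rpyc helpers added in this hunk (start_rpyc_service_process and connect_rpyc_service) could be paired together; the EchoService class and port below are purely illustrative, and the sketch assumes sglang.srt.utils is importable:

import rpyc
from sglang.srt.utils import connect_rpyc_service, start_rpyc_service_process

class EchoService(rpyc.Service):
    # exposed_* methods become plain attributes on the proxy returned by connect_rpyc_service
    def exposed_echo(self, value):
        return value

proc = start_rpyc_service_process(EchoService(), port=18861)  # ThreadedServer in a child process
service = connect_rpyc_service("localhost", 18861)            # retries for up to ~20 seconds
assert service.echo("ping") == "ping"
proc.terminate()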
@@ -416,11 +530,64 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         return response
 
 
-
-
+def get_ip_address(ifname):
+    """
+    Get the IP address of a network interface.
+
+    :param ifname: Name of the network interface (e.g., 'eth0')
+    :return: IP address of the network interface
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    ip_address = fcntl.ioctl(
+        s.fileno(),
+        0x8915,  # SIOCGIFADDR
+        struct.pack('256s', bytes(ifname[:15], 'utf-8'))
+    )[20:24]
+    return socket.inet_ntoa(ip_address)
+
+
+def send_addrs_to_rank_0(model_port_args, server_args):
+    assert server_args.node_rank != 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+    ip_addr = [int(x) for x in ip_addr.split(".")]
+    addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    dist.send(addrs_tensor, dst=0)
+    print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+def receive_addrs(model_port_args, server_args):
+    assert server_args.node_rank == 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
 
+    for src_rank in range(1, server_args.nnodes):
+        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
+        dist.recv(tensor, src=src_rank)
+        ip = ".".join([str(x) for x in tensor[:4].tolist()])
+        ports = tensor[4:].tolist()
+        model_port_args.model_tp_ips[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = ports
+        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
 
-
-
-        return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
+    dist.barrier()
+    dist.destroy_process_group()
sglang/test/test_programs.py
CHANGED
@@ -1,6 +1,4 @@
-"""
-This file contains the SGL programs used for unit testing.
-"""
+"""This file contains the SGL programs used for unit testing."""
 
 import json
 import re
@@ -304,6 +302,7 @@ def test_image_qa():
         temperature=0,
         max_new_tokens=64,
     )
+
    assert (
        "taxi" in state.messages()[-1]["content"]
        or "car" in state.messages()[-1]["content"]
@@ -349,3 +348,66 @@ def test_regex():
     state = regex_gen.run()
     answer = state["answer"]
     assert re.match(regex, answer)
+
+
+def test_completion_speculative():
+    @sgl.function(num_api_spec_tokens=64)
+    def gen_character_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    @sgl.function
+    def gen_character_no_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    token_usage = sgl.global_config.default_backend.token_usage
+
+    token_usage.reset()
+    gen_character_spec().sync()
+    usage_with_spec = token_usage.prompt_tokens
+
+    token_usage.reset()
+    gen_character_no_spec().sync()
+    usage_with_no_spec = token_usage.prompt_tokens
+
+    assert (
+        usage_with_spec < usage_with_no_spec
+    ), f"{usage_with_spec} vs {usage_with_no_spec}"
+
+
+def test_chat_completion_speculative():
+    @sgl.function(num_api_spec_tokens=256)
+    def gen_character_spec(s):
+        s += sgl.system("You are a helpful assistant.")
+        s += sgl.user("Construct a character within the following format:")
+        s += sgl.assistant(
+            "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        )
+        s += sgl.user("Please generate new Name, Birthday and Job.\n")
+        s += sgl.assistant(
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+            + "\nJob:"
+            + sgl.gen("job", stop="\n")
+        )
+
+    gen_character_spec().sync()
sglang/test/test_utils.py
CHANGED
@@ -9,7 +9,7 @@ import requests
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
-from sglang.
+from sglang.utils import get_exception_traceback
 
 
 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
@@ -88,6 +88,33 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
     return pred
 
 
+def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
+    import grpc
+    from ginfer import sampler_pb2, sampler_pb2_grpc
+
+    sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
+    sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
+
+    if stop is None:
+        stop_strings = None
+    else:
+        stop_strings = [stop]
+
+    sample_request = sampler_pb2.SampleTextRequest(
+        prompt=prompt,
+        settings=sampler_pb2.SampleSettings(
+            max_len=max_tokens,
+            rng_seed=0,
+            temperature=max(temperature, 1e-7),
+            nucleus_p=1,
+            stop_strings=stop_strings,
+        ),
+    )
+    stream = sampler.SampleText(sample_request)
+    response = "".join([x.text for x in stream])
+    return response
+
+
 def call_generate_guidance(
     prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
 ):
@@ -228,6 +255,7 @@ def add_common_other_args_and_parse(parser):
             "vllm",
             "outlines",
             "lightllm",
+            "ginfer",
             "guidance",
             "lmql",
             "srt-raw",
@@ -248,6 +276,7 @@ def add_common_other_args_and_parse(parser):
         "lightllm": 22000,
         "lmql": 23000,
         "srt-raw": 30000,
+        "ginfer": 9988,
     }
     args.port = default_port.get(args.backend, None)
     return args
@@ -283,6 +312,8 @@ def _get_call_generate(args):
         return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "srt-raw":
         return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "ginfer":
+        return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
     elif args.backend == "outlines":
         return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "guidance":
sglang/utils.py
CHANGED
@@ -2,7 +2,8 @@
 
 import base64
 import json
-import
+import logging
+import signal
 import sys
 import threading
 import traceback
@@ -14,6 +15,8 @@ from json import dumps
 import numpy as np
 import requests
 
+logger = logging.getLogger(__name__)
+
 
 def get_exception_traceback():
     etype, value, tb = sys.exc_info()
@@ -93,8 +96,12 @@ def http_request(
         data = None
     else:
         data = bytes(dumps(json), encoding="utf-8")
-
-
+
+    try:
+        resp = urllib.request.urlopen(req, data=data, cafile=verify)
+        return HttpResponse(resp)
+    except urllib.error.HTTPError as e:
+        return HttpResponse(e)
 
 
 def encode_image_base64(image_path):
@@ -137,7 +144,8 @@ def encode_frame(frame):
 
 
 def encode_video_base64(video_path, num_frames=16):
-    import cv2
+    import cv2 # pip install opencv-python-headless
+
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         raise IOError(f"Could not open video file:{video_path}")
@@ -242,3 +250,14 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
         raise RuntimeError()
 
     return ret_value[0]
+
+
+def graceful_registry(sub_module_name):
+    def graceful_shutdown(signum, frame):
+        logger.info(
+            f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
+        )
+        if signum == signal.SIGTERM:
+            logger.info(f"{sub_module_name} recive sigterm")
+
+    signal.signal(signal.SIGTERM, graceful_shutdown)