sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
@@ -1,25 +1,31 @@
 """Common utilities."""

 import base64
+import multiprocessing
+import logging
 import os
 import random
 import socket
-import sys
 import time
-import traceback
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional

 import numpy as np
-import
+import psutil
 import requests
+import rpyc
 import torch
+import triton
+from rpyc.utils.server import ThreadedServer
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from pydantic import BaseModel
 from starlette.middleware.base import BaseHTTPMiddleware

+
+logger = logging.getLogger(__name__)
+
+
 show_time_cost = False
 time_infos = {}

@@ -90,37 +96,49 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def
-
+def get_available_gpu_memory(gpu_id, distributed=False):
+    """
+    Get available memory for cuda:gpu_id device.
+    When distributed is True, the available memory is the minimum available memory of all GPUs.
+    """
+    num_gpus = torch.cuda.device_count()
+    assert gpu_id < num_gpus

-    torch.
-
-
+    if torch.cuda.current_device() != gpu_id:
+        print(
+            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+            "which may cause useless memory allocation for torch CUDA context.",
+        )
+
+    torch.cuda.empty_cache()
+    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

+    if distributed:
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
+            torch.device("cuda", gpu_id)
+        )
+        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
+        free_gpu_memory = tensor.item()

-
-    port_list = []
-    for port in range(10000, 65536):
-        if port in used_list:
-            continue
+    return free_gpu_memory / (1 << 30)

-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        try:
-            s.bind(("", port))
-            port_list.append(port)
-        except socket.error:
-            pass

-
-
-
+def set_random_seed(seed: int) -> None:
+    """Set the random seed for all libraries."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)


-def
+def is_port_available(port):
+    """Return whether a port is available."""
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         try:
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             s.bind(("", port))
+            s.listen(1)
             return True
         except socket.error:
             return False
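The new get_available_gpu_memory reports free memory in GiB and, with distributed=True, all-reduces to the minimum across ranks so every worker budgets against the tightest GPU. A minimal usage sketch, assuming sglang 0.1.17 and at least one CUDA device:

from sglang.srt.utils import get_available_gpu_memory

# Free memory on cuda:0 in GiB. With distributed=True (inside an initialized
# torch.distributed group) every rank receives the minimum across all ranks.
free_gb = get_available_gpu_memory(0)
print(f"cuda:0 has {free_gb:.1f} GiB free")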
@@ -130,41 +148,34 @@ def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
+    dp_size: int = 1,
 ):
-
-
-
-
-
-
-
-
-        print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
-        port = new_port
-
-    # then we check on additional ports
-    additional_unique_ports = set(additional_ports) - {port}
-    # filter out ports that are already in use
-    can_use_ports = [port for port in additional_unique_ports if check_port(port)]
-
-    num_specified_ports = len(can_use_ports)
-    if num_specified_ports < 4 + tp_size:
-        addtional_can_use_ports = alloc_usable_network_port(
-            num=4 + tp_size - num_specified_ports, used_list=can_use_ports + [port]
-        )
-        can_use_ports.extend(addtional_can_use_ports)
+    """Allocate ports for all connections."""
+    if additional_ports:
+        ret_ports = [port] + additional_ports
+    else:
+        ret_ports = [port]
+
+    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
+    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000

-
-
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
+    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    while len(ret_ports) < num_ports_needed:
+        if cur_port not in ret_ports and is_port_available(cur_port):
+            ret_ports.append(cur_port)
+        cur_port += 1

+    if port is not None and ret_ports[0] != port:
+        logger.warn(
+            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
+        )

-
-    etype, value, tb = sys.exc_info()
-    err_str = "".join(traceback.format_exception(etype, value, tb))
-    return err_str
+    return ret_ports[0], ret_ports[1:num_ports_needed]


 def get_int_token_logit_bias(tokenizer, vocab_size):
+    """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
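The comment above spells out the port budget: four fixed ports (HTTP, tokenizer, controller, detokenizer) plus, per data-parallel replica, one NCCL port and one port per tensor-parallel worker. A worked example with illustrative sizes:

# Illustrative sizes, not defaults: 2 tensor-parallel workers per replica, 2 replicas.
tp_size, dp_size = 2, 2
num_ports_needed = 4 + dp_size * (1 + tp_size)  # 4 fixed ports + per-replica (nccl + tp workers)
assert num_ports_needed == 10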
@@ -178,14 +189,11 @@ def get_int_token_logit_bias(tokenizer, vocab_size):

 def wrap_kernel_launcher(kernel):
     """A faster launcher for triton kernels."""
-
-
-    if dist.is_initialized():
-        rank = dist.get_rank()
-    else:
-        rank = 0
+    if int(triton.__version__.split(".")[0]) >= 3:
+        return None

-
+    gpu_id = torch.cuda.current_device()
+    kernels = kernel.cache[gpu_id].values()
     kernel = next(iter(kernels))

     # Different trition versions use different low-level names
@@ -245,20 +253,104 @@ def wrap_kernel_launcher(kernel):


 def is_multimodal_model(model):
-    if isinstance(model, str):
-        return "llava" in model or "yi-vl" in model
     from sglang.srt.model_config import ModelConfig

+    if isinstance(model, str):
+        model = model.lower()
+        return "llava" in model or "yi-vl" in model or "llava-next" in model
+
     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return
-
+        return (
+            "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+        )
+
+    raise ValueError("unrecognized type")
+
+
+def decode_video_base64(video_base64):
+    from PIL import Image
+
+    # Decode the base64 string
+    video_bytes = base64.b64decode(video_base64)
+
+    # Placeholder for the start indices of each PNG image
+    img_starts = []
+
+    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    assert frame_format in [
+        "PNG",
+        "JPEG",
+    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
+
+    if frame_format == "PNG":
+        # Find each PNG start signature to isolate images
+        i = 0
+        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
+            # Check if we found the start of a PNG file
+            if (
+                video_bytes[i] == 0x89
+                and video_bytes[i + 1] == 0x50
+                and video_bytes[i + 2] == 0x4E
+                and video_bytes[i + 3] == 0x47
+                and video_bytes[i + 4] == 0x0D
+                and video_bytes[i + 5] == 0x0A
+                and video_bytes[i + 6] == 0x1A
+                and video_bytes[i + 7] == 0x0A
+            ):
+                img_starts.append(i)
+                i += 8  # Skip the PNG signature
+            else:
+                i += 1
+    else:
+        # Find each JPEG start (0xFFD8) to isolate images
+        i = 0
+        while (
+            i < len(video_bytes) - 1
+        ):  # Adjusted for the length of the JPEG SOI signature
+            # Check if we found the start of a JPEG file
+            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
+                img_starts.append(i)
+                # Move to the next byte to continue searching for the next image start
+                i += 2
+            else:
+                i += 1

+    frames = []
+    for start_idx in img_starts:
+        # Assuming each image is back-to-back, the end of one image is the start of another
+        # The last image goes until the end of the byte string
+        end_idx = (
+            img_starts[img_starts.index(start_idx) + 1]
+            if img_starts.index(start_idx) + 1 < len(img_starts)
+            else len(video_bytes)
+        )
+        img_bytes = video_bytes[start_idx:end_idx]
+
+        # Convert bytes to a PIL Image
+        img = Image.open(BytesIO(img_bytes))
+
+        # Convert PIL Image to a NumPy array
+        frame = np.array(img)
+
+        # Append the frame to the list of frames
+        frames.append(frame)
+
+    # Ensure there's at least one frame to avoid errors with np.stack
+    if frames:
+        return np.stack(frames, axis=0), img.size
+    else:
+        return np.array([]), (
+            0,
+            0,
+        )  # Return an empty array and size tuple if no frames were found


 def load_image(image_file):
     from PIL import Image

-    image = None
+    image = image_size = None

     if image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
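The PNG branch of decode_video_base64 splits frames by scanning for the fixed 8-byte PNG signature (0x89 'P' 'N' 'G' 0x0D 0x0A 0x1A 0x0A). A quick sanity check of that signature, assuming only Pillow is installed:

from io import BytesIO
from PIL import Image

buf = BytesIO()
Image.new("RGB", (4, 4)).save(buf, format="PNG")
# Every PNG stream begins with these 8 bytes, which the scanner above matches byte by byte.
assert buf.getvalue()[:8] == b"\x89PNG\r\n\x1a\n"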
@@ -269,10 +361,71 @@ def load_image(image_file):
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(base64.b64decode(image_file)))
+    elif image_file.startswith("video:"):
+        image_file = image_file.replace("video:", "")
+        image, image_size = decode_video_base64(image_file)
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))

-    return image
+    return image, image_size
+
+
+def init_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def connect_to_rpyc_service(port, host="localhost"):
+    time.sleep(1)
+
+    repeat_count = 0
+    while repeat_count < 20:
+        try:
+            con = rpyc.connect(
+                host,
+                port,
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 3600
+                },
+            )
+            break
+        except ConnectionRefusedError:
+            time.sleep(1)
+        repeat_count += 1
+    if repeat_count == 20:
+        raise RuntimeError("init rpc env error!")
+
+    return con.root
+
+
+def start_rpyc_process(service: rpyc.Service, port: int):
+    # Return the proxy and the process
+    proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
+    proc.start()
+    proxy = connect_to_rpyc_service(port)
+    assert proc.is_alive()
+    return proxy, proc
+
+
+def suppress_other_loggers():
+    from vllm.logger import logger as vllm_default_logger
+
+    vllm_default_logger.setLevel(logging.WARN)
+    logging.getLogger("vllm.config").setLevel(logging.ERROR)
+    logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN)
+    logging.getLogger("vllm.selector").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.WARN)


 def assert_pkg_version(pkg: str, min_version: str):
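The new rpyc helpers pair a ThreadedServer child process with a retrying client connection. A minimal sketch of how they compose, assuming rpyc is installed; the service name and port below are illustrative:

import rpyc
from sglang.srt.utils import start_rpyc_process

class EchoService(rpyc.Service):
    # rpyc exposes methods prefixed with "exposed_" to remote callers
    def exposed_echo(self, msg):
        return msg

if __name__ == "__main__":
    proxy, proc = start_rpyc_process(EchoService, port=18861)  # port is illustrative
    print(proxy.echo("hello"))  # executed in the child process, returned over RPC
    proc.terminate()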
@@ -284,7 +437,30 @@ def assert_pkg_version(pkg: str, min_version: str):
                 f"is less than the minimum required version {min_version}"
             )
     except PackageNotFoundError:
-        raise Exception(
+        raise Exception(
+            f"{pkg} with minimum required version {min_version} is not installed"
+        )
+
+
+def kill_parent_process():
+    """Kill the parent process and all children of the parent process."""
+    current_process = psutil.Process()
+    parent_process = current_process.parent()
+    children = current_process.children(recursive=True)
+    for child in children:
+        if child.pid != current_process.pid:
+            os.kill(child.pid, 9)
+    os.kill(parent_process.pid, 9)
+
+
+def monkey_patch_vllm_p2p_access_check():
+    """
+    Monkey patch the slow p2p access check in vllm.
+    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+    """
+    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)


 API_KEY_HEADER_NAME = "X-API-Key"
@@ -306,12 +482,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         response = await call_next(request)
         return response

-
-# FIXME: Remove this once we drop support for pydantic 1.x
-IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
-def jsonify_pydantic_model(obj: BaseModel):
-    if IS_PYDANTIC_1:
-        return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
sglang/test/test_programs.py
CHANGED
@@ -304,6 +304,7 @@ def test_image_qa():
         temperature=0,
         max_new_tokens=64,
     )
+
     assert (
         "taxi" in state.messages()[-1]["content"]
        or "car" in state.messages()[-1]["content"]
@@ -349,3 +350,46 @@ def test_regex():
     state = regex_gen.run()
     answer = state["answer"]
     assert re.match(regex, answer)
+
+
+def test_completion_speculative():
+    @sgl.function(num_api_spec_tokens=64)
+    def gen_character_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+
+    @sgl.function
+    def gen_character_no_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    token_usage = sgl.global_config.default_backend.token_usage
+
+    token_usage.reset()
+    gen_character_spec().sync()
+    usage_with_spec = token_usage.prompt_tokens
+
+    token_usage.reset()
+    gen_character_no_spec().sync()
+    usage_with_no_spec = token_usage.prompt_tokens
+
+    assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
+
+
+def test_chat_completion_speculative():
+    @sgl.function(num_api_spec_tokens=256)
+    def gen_character_spec(s):
+        s += sgl.system("You are a helpful assistant.")
+        s += sgl.user("Construct a character within the following format:")
+        s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
+        s += sgl.user("Please generate new Name, Birthday and Job.\n")
+        s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+
+    gen_character_spec().sync()
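Both speculative tests read sgl.global_config.default_backend.token_usage, so an OpenAI-style backend has to be registered before they run. A hedged setup sketch; the model name is illustrative and an OPENAI_API_KEY is assumed:

import sglang as sgl
from sglang.backend.openai import OpenAI
from sglang.test.test_programs import (
    test_completion_speculative,
    test_chat_completion_speculative,
)

sgl.set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))  # model name is illustrative
test_completion_speculative()        # the speculative run should consume fewer prompt tokens
test_chat_completion_speculative()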
sglang/test/test_utils.py
CHANGED
@@ -9,7 +9,7 @@ import requests
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
-from sglang.
+from sglang.utils import get_exception_traceback


 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
@@ -88,6 +88,33 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
     return pred


+def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
+    import grpc
+    from ginfer import sampler_pb2, sampler_pb2_grpc
+
+    sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
+    sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
+
+    if stop is None:
+        stop_strings = None
+    else:
+        stop_strings = [stop]
+
+    sample_request = sampler_pb2.SampleTextRequest(
+        prompt=prompt,
+        settings=sampler_pb2.SampleSettings(
+            max_len=max_tokens,
+            rng_seed=0,
+            temperature=max(temperature, 1e-7),
+            nucleus_p=1,
+            stop_strings=stop_strings,
+        ),
+    )
+    stream = sampler.SampleText(sample_request)
+    response = "".join([x.text for x in stream])
+    return response
+
+
 def call_generate_guidance(
     prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
 ):
@@ -228,6 +255,7 @@ def add_common_other_args_and_parse(parser):
             "vllm",
             "outlines",
             "lightllm",
+            "ginfer",
             "guidance",
             "lmql",
             "srt-raw",
@@ -248,6 +276,7 @@ def add_common_other_args_and_parse(parser):
         "lightllm": 22000,
         "lmql": 23000,
         "srt-raw": 30000,
+        "ginfer": 9988,
     }
     args.port = default_port.get(args.backend, None)
     return args
@@ -283,6 +312,8 @@ def _get_call_generate(args):
         return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "srt-raw":
         return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "ginfer":
+        return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
     elif args.backend == "outlines":
         return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "guidance":
sglang/utils.py
CHANGED
@@ -2,40 +2,27 @@

 import base64
 import json
+import logging
+import signal
+import sys
 import threading
+import traceback
 import urllib.request
+from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps

+import numpy as np
 import requests


-
-    """
-    Get available memory for cuda:gpu_id device.
-    When distributed is True, the available memory is the minimum available memory of all GPUs.
-    """
-    import torch
+logger = logging.getLogger(__name__)

-    num_gpus = torch.cuda.device_count()
-    assert gpu_id < num_gpus

-
-
-
-
-    )
-
-    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
-
-    if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device("cuda", gpu_id)
-        )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
-        free_gpu_memory = tensor.item()
-
-    return free_gpu_memory / (1 << 30)
+def get_exception_traceback():
+    etype, value, tb = sys.exc_info()
+    err_str = "".join(traceback.format_exception(etype, value, tb))
+    return err_str


 def is_same_type(values):
@@ -110,8 +97,12 @@ def http_request(
         data = None
     else:
         data = bytes(dumps(json), encoding="utf-8")
-
-
+
+    try:
+        resp = urllib.request.urlopen(req, data=data, cafile=verify)
+        return HttpResponse(resp)
+    except urllib.error.HTTPError as e:
+        return HttpResponse(e)


 def encode_image_base64(image_path):
@@ -130,6 +121,75 @@ def encode_image_base64(image_path):
     return base64.b64encode(buffered.getvalue()).decode("utf-8")


+def encode_frame(frame):
+    import cv2  # pip install opencv-python-headless
+    from PIL import Image
+
+    # Convert the frame to RGB (OpenCV uses BGR by default)
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+    # Convert the frame to PIL Image to easily convert to bytes
+    im_pil = Image.fromarray(frame)
+
+    # Convert to bytes
+    buffered = BytesIO()
+
+    # frame_format = str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    im_pil.save(buffered, format="PNG")
+
+    frame_bytes = buffered.getvalue()
+
+    # Return the bytes of the frame
+    return frame_bytes
+
+
+def encode_video_base64(video_path, num_frames=16):
+    import cv2  # pip install opencv-python-headless
+
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        raise IOError(f"Could not open video file:{video_path}")
+
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    print(f"target_frames: {num_frames}")
+
+    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+
+    frames = []
+    for i in range(total_frames):
+        ret, frame = cap.read()
+        if ret:
+            frames.append(frame)
+        else:
+            # Handle the case where the frame could not be read
+            # print(f"Warning: Could not read frame at index {i}.")
+            pass
+
+    cap.release()
+
+    # Safely select frames based on frame_indices, avoiding IndexError
+    frames = [frames[i] for i in frame_indices if i < len(frames)]
+
+    # If there are not enough frames, duplicate the last frame until we reach the target
+    while len(frames) < num_frames:
+        frames.append(frames[-1])
+
+    # Use ThreadPoolExecutor to process and encode frames in parallel
+    with ThreadPoolExecutor() as executor:
+        encoded_frames = list(executor.map(encode_frame, frames))
+
+    # encoded_frames = list(map(encode_frame, frames))
+
+    # Concatenate all frames bytes
+    video_bytes = b"".join(encoded_frames)
+
+    # Encode the concatenated bytes to base64
+    video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
+
+    return video_base64
+
+
 def _is_chinese_char(cp):
     """Checks whether CP is the codepoint of a CJK character."""
     # This defines a "chinese character" as anything in the CJK Unicode block:
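encode_video_base64 samples num_frames frames, PNG-encodes each, concatenates the bytes, and returns them base64-encoded behind a "video:" prefix; decode_video_base64 in sglang/srt/utils.py reverses this. A round-trip sketch, where "clip.mp4" is a placeholder path and opencv-python-headless plus Pillow are assumed:

from sglang.srt.utils import decode_video_base64
from sglang.utils import encode_video_base64

video_b64 = encode_video_base64("clip.mp4", num_frames=8)            # "video:<base64>"
frames, last_size = decode_video_base64(video_b64.replace("video:", "", 1))
print(frames.shape, last_size)  # (8, H, W, C) and the last frame's (width, height)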
@@ -191,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
         raise RuntimeError()

     return ret_value[0]
+
+
+def graceful_registry(sub_module_name):
+    def graceful_shutdown(signum, frame):
+        logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
+        if signum == signal.SIGTERM:
+            logger.info(f"{sub_module_name} recive sigterm")
+
+    signal.signal(signal.SIGTERM, graceful_shutdown)
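graceful_registry only installs a SIGTERM handler that logs the shutdown request under the given sub-module name. A usage sketch for a worker entry point; the name and loop body are illustrative:

import time
from sglang.utils import graceful_registry

def worker_main():
    # Install the SIGTERM handler before entering the worker loop; the name is illustrative.
    graceful_registry("example-worker")
    while True:
        time.sleep(1)  # placeholder for the real worker loop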