sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +57 -2
- sglang/api.py +8 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +83 -2
- sglang/lang/interpreter.py +92 -35
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +6 -4
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +10 -2
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +27 -3
- sglang/srt/managers/router/infer_batch.py +97 -48
- sglang/srt/managers/router/manager.py +11 -8
- sglang/srt/managers/router/model_rpc.py +169 -90
- sglang/srt/managers/router/model_runner.py +110 -166
- sglang/srt/managers/router/radix_cache.py +89 -51
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +110 -33
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +11 -0
- sglang/srt/models/commandr.py +372 -0
- sglang/srt/models/dbrx.py +412 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +24 -25
- sglang/srt/models/llama2.py +25 -26
- sglang/srt/models/llava.py +8 -10
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +29 -33
- sglang/srt/models/qwen.py +34 -25
- sglang/srt/models/qwen2.py +25 -26
- sglang/srt/models/stablelm.py +26 -26
- sglang/srt/models/yivl.py +3 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +91 -456
- sglang/srt/server_args.py +79 -49
- sglang/srt/utils.py +212 -47
- sglang/srt/weight_utils.py +417 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- sglang/utils.py +77 -26
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
```diff
@@ -1,3 +1,5 @@
+"""The arguments of the server."""
+
 import argparse
 import dataclasses
 from typing import List, Optional, Union
@@ -5,34 +7,47 @@ from typing import List, Optional, Union
 
 @dataclasses.dataclass
 class ServerArgs:
+    # Model and tokenizer
     model_path: str
     tokenizer_path: Optional[str] = None
-    host: str = "127.0.0.1"
-    port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None
     load_format: str = "auto"
     tokenizer_mode: str = "auto"
    chat_template: Optional[str] = None
     trust_remote_code: bool = True
+    context_length: Optional[int] = None
+
+    # Port
+    host: str = "127.0.0.1"
+    port: int = 30000
+    additional_ports: Optional[Union[List[int], int]] = None
+
+    # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_prefill_num_token: Optional[int] = None
-    context_length: Optional[int] = None
-    tp_size: int = 1
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
-    attention_reduce_in_fp32: bool = False
-    random_seed: int = 42
+
+    # Other runtime options
+    tp_size: int = 1
     stream_interval: int = 8
+    random_seed: int = 42
+
+    # Logging
+    log_level: str = "info"
+    log_requests: bool = False
     disable_log_stats: bool = False
     log_stats_interval: int = 10
-    log_level: str = "info"
+    show_time_cost: bool = False
 
-    # optional modes
-    disable_radix_cache: bool = False
+    # Other
+    api_key: str = ""
+
+    # Optimization/debug options
     enable_flashinfer: bool = False
+    attention_reduce_in_fp32: bool = False
+    disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
-    api_key: str = ""
 
     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -65,15 +80,18 @@
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument("--host", type=str, default=ServerArgs.host)
-        parser.add_argument("--port", type=int, default=ServerArgs.port)
-
+        parser.add_argument(
+            "--host", type=str, default=ServerArgs.host, help="The host of the server."
+        )
+        parser.add_argument(
+            "--port", type=int, default=ServerArgs.port, help="The port of the server."
+        )
         parser.add_argument(
             "--additional-ports",
             type=int,
             nargs="*",
             default=[],
-            help="Additional ports specified for
+            help="Additional ports specified for the server.",
         )
         parser.add_argument(
             "--load-format",
@@ -111,6 +129,12 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -123,23 +147,12 @@ class ServerArgs:
             default=ServerArgs.max_prefill_num_token,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
-        )
-        parser.add_argument(
-            "--tp-size",
-            type=int,
-            default=ServerArgs.tp_size,
-            help="Tensor parallelism degree.",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
             default=ServerArgs.schedule_heuristic,
-
+            choices=["lpm", "random", "fcfs", "dfs-weight"],
+            help="Scheduling Heuristic.",
         )
         parser.add_argument(
             "--schedule-conservativeness",
@@ -148,15 +161,10 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
-            "--random-seed",
+            "--tp-size",
             type=int,
-            default=ServerArgs.random_seed,
-            help="Random seed.",
-        )
-        parser.add_argument(
-            "--attention-reduce-in-fp32",
-            action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+            default=ServerArgs.tp_size,
+            help="Tensor parallelism size.",
         )
         parser.add_argument(
             "--stream-interval",
@@ -164,11 +172,22 @@ class ServerArgs:
             default=ServerArgs.stream_interval,
             help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
+        parser.add_argument(
+            "--random-seed",
+            type=int,
+            default=ServerArgs.random_seed,
+            help="Random seed.",
+        )
         parser.add_argument(
             "--log-level",
             type=str,
             default=ServerArgs.log_level,
-            help="
+            help="Logging level",
+        )
+        parser.add_argument(
+            "--log-requests",
+            action="store_true",
+            help="Log all requests",
         )
         parser.add_argument(
             "--disable-log-stats",
@@ -181,17 +200,34 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
-        # optional modes
         parser.add_argument(
-            "--disable-radix-cache",
+            "--show-time-cost",
             action="store_true",
-            help="Disable RadixAttention",
+            help="Show time cost of custom marks",
+        )
+        parser.add_argument(
+            "--api-key",
+            type=str,
+            default=ServerArgs.api_key,
+            help="Set API key of the server",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
             action="store_true",
             help="Enable flashinfer inference kernels",
         )
+        parser.add_argument(
+            "--attention-reduce-in-fp32",
+            action="store_true",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+        )
+        parser.add_argument(
+            "--disable-radix-cache",
+            action="store_true",
+            help="Disable RadixAttention",
+        )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
@@ -202,12 +238,6 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
-        parser.add_argument(
-            "--api-key",
-            type=str,
-            default=ServerArgs.api_key,
-            help="Set API Key",
-        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -217,13 +247,13 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
-    def 
+    def print_mode_args(self):
         return (
-            f"disable_radix_cache={self.disable_radix_cache}, "
             f"enable_flashinfer={self.enable_flashinfer}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
+            f"disable_radix_cache={self.disable_radix_cache}, "
             f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
             f"disable_disk_cache={self.disable_disk_cache}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
         )
```
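The reorganization keeps the package's dataclass-mirrors-argparse pattern: every `ServerArgs` field doubles as the default of a CLI flag, and `from_cli_args` (visible as context above) rebuilds the dataclass from the parsed namespace. A minimal sketch of that round trip; the name of the parser-building static method is not visible in this diff, so `add_cli_args` is an assumption:

```python
import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)  # assumed method name; only its body appears in this diff

# Defaults come from the dataclass fields, overrides from the command line.
args = parser.parse_args(["--model-path", "meta-llama/Llama-2-7b-hf", "--tp-size", "2"])
server_args = ServerArgs.from_cli_args(args)

print(server_args.url())              # http://127.0.0.1:30000
print(server_args.print_mode_args())  # the optimization/debug flags shown above
```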
sglang/srt/utils.py
CHANGED
```diff
@@ -1,58 +1,74 @@
+"""Common utilities."""
+
 import base64
 import os
 import random
 import socket
-import sys
 import time
-import traceback
+from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional
 
 import numpy as np
+import pydantic
 import requests
 import torch
-import torch.distributed as dist
+from fastapi.responses import JSONResponse
+from packaging import version as pkg_version
+from pydantic import BaseModel
+from starlette.middleware.base import BaseHTTPMiddleware
 
-is_show_cost_time = False
+from sglang.utils import get_exception_traceback
 
+show_time_cost = False
+time_infos = {}
 
-def mark_cost_time(func_name):
-    def inner_func(func):
-        def time_func(*args, **kwargs):
-            if dist.get_rank() in [0, 1] and is_show_cost_time:
-                torch.cuda.synchronize()
-                start_time = time.time()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                print(func_name, "cost time:", (time.time() - start_time) * 1000)
-                return ans
-            else:
-                torch.cuda.synchronize()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                return ans
 
-        return time_func
+def enable_show_time_cost():
+    global show_time_cost
+    show_time_cost = True
 
-    return inner_func
 
+class TimeInfo:
+    def __init__(self, name, interval=0.1, color=0, indent=0):
+        self.name = name
+        self.interval = interval
+        self.color = color
+        self.indent = indent
+
+        self.acc_time = 0
+        self.last_acc_time = 0
+
+    def check(self):
+        if self.acc_time - self.last_acc_time > self.interval:
+            self.last_acc_time = self.acc_time
+            return True
+        return False
 
-
+    def pretty_print(self):
+        print(f"\x1b[{self.color}m", end="")
+        print("-" * self.indent * 2, end="")
+        print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")
 
 
-def mark_start(key):
+def mark_start(name, interval=0.1, color=0, indent=0):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-
-
-
+    if time_infos.get(name, None) is None:
+        time_infos[name] = TimeInfo(name, interval, color, indent)
+    time_infos[name].acc_time -= time.time()
 
 
-def mark_end(key):
+def mark_end(name):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-
-
-
-        print(f"cost {key}:", cost_time)
+    time_infos[name].acc_time += time.time()
+    if time_infos[name].check():
+        time_infos[name].pretty_print()
 
 
 def calculate_time(show=False, min_cost_ms=0.0):
```
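The old rank-gated `mark_cost_time` decorator is replaced by named, accumulating timers that stay inert until `enable_show_time_cost()` flips the module-level flag (the new `--show-time-cost` server flag is the intended switch). A small usage sketch built only on the functions added above; the workload is a stand-in and a CUDA device is required, since both marks synchronize:

```python
import torch

from sglang.srt.utils import enable_show_time_cost, mark_end, mark_start

def run_one_decode_step():
    # Stand-in for real work; any CUDA activity between the marks is timed.
    torch.ones(1024, 1024, device="cuda") @ torch.ones(1024, 1024, device="cuda")

enable_show_time_cost()  # without this, mark_start/mark_end return immediately

for _ in range(100):
    mark_start("decode_step", interval=0.5, color=31, indent=1)  # 31 = red ANSI
    run_one_decode_step()
    mark_end("decode_step")
    # Accumulated wall time for "decode_step" is pretty-printed each time
    # another `interval` seconds of measured time has built up.
```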
```diff
@@ -74,6 +90,32 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
+def get_available_gpu_memory(gpu_id, distributed=True):
+    """
+    Get available memory for cuda:gpu_id device.
+    When distributed is True, the available memory is the minimum available memory of all GPUs.
+    """
+    num_gpus = torch.cuda.device_count()
+    assert gpu_id < num_gpus
+
+    if torch.cuda.current_device() != gpu_id:
+        print(
+            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+            "which may cause useless memory allocation for torch CUDA context.",
+        )
+
+    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
+
+    if distributed:
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
+            torch.device("cuda", gpu_id)
+        )
+        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
+        free_gpu_memory = tensor.item()
+
+    return free_gpu_memory / (1 << 30)
+
+
 def set_random_seed(seed: int) -> None:
     random.seed(seed)
```
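Note the `distributed=True` path: every rank contributes its own free memory and the MIN all-reduce leaves each rank planning against the most constrained GPU, so tensor-parallel workers cannot disagree about how much cache fits. A hedged sketch of a caller; the budgeting arithmetic is illustrative, not the package's actual logic in `model_runner.py`:

```python
import torch.distributed as dist

from sglang.srt.utils import get_available_gpu_memory

# Requires an initialized process group when distributed=True,
# since the function performs a MIN all-reduce across ranks.
free_gb = get_available_gpu_memory(gpu_id=0, distributed=dist.is_initialized())

# Illustrative only: budget a fraction of the globally smallest free memory,
# in the spirit of the --mem-fraction-static server argument.
kv_cache_budget_gb = 0.85 * free_gb
```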
```diff
@@ -89,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
             continue
 
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             try:
                 s.bind(("", port))
+                s.listen(1)  # Attempt to listen on the port
                 port_list.append(port)
             except socket.error:
-                pass
+                pass  # If any error occurs, this port is not usable
 
         if len(port_list) == num:
             return port_list
```
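The two functional changes here are `SO_REUSEADDR`, which keeps the probe from rejecting ports whose only occupants are sockets lingering in TIME_WAIT, and `listen(1)`, which makes the check stricter than a bare `bind`. The same probe as a standalone helper (the function name is mine, not the package's):

```python
import socket

def port_is_usable(port: int) -> bool:
    """Mirror of the probe above: a port counts as usable only if we can
    both bind and listen on it."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            s.bind(("", port))
            s.listen(1)
            return True
        except socket.error:
            return False

print(port_is_usable(30000))
```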
```diff
@@ -110,7 +154,7 @@ def check_port(port):
     return False
 
 
-def handle_port_init(
+def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
@@ -142,15 +186,7 @@ def handle_port_init(
     return port, additional_ports
 
 
-def get_exception_traceback():
-    etype, value, tb = sys.exc_info()
-    err_str = "".join(traceback.format_exception(etype, value, tb))
-    return err_str
-
-
 def get_int_token_logit_bias(tokenizer, vocab_size):
-    from transformers import LlamaTokenizer, LlamaTokenizerFast
-
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
```
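`get_int_token_logit_bias` itself is unchanged apart from the dropped transformers import (`get_exception_traceback` simply moved to `sglang.utils`). Judging by its name and the comment kept above, it builds a vocab-sized bias array for steering decoding toward integer-only tokens. A hedged sketch of the idea, illustrative rather than the package's exact implementation:

```python
import numpy as np

def int_token_logit_bias_sketch(tokenizer, vocab_size):
    # Assumption: heavily penalize every token whose decoded text is not a
    # pure digit string, so sampling can only produce integers.
    bias = np.zeros(vocab_size, dtype=np.float32)
    for token_id in range(vocab_size):
        text = tokenizer.decode([token_id]).strip()
        if not text.isdigit():
            bias[token_id] = -1e5
    return bias
```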
```diff
@@ -231,20 +267,102 @@ def wrap_kernel_launcher(kernel):
 
 
 def is_multimodal_model(model):
-    if isinstance(model, str):
-        return "llava" in model or "yi-vl" in model
     from sglang.srt.model_config import ModelConfig
 
+    if isinstance(model, str):
+        model = model.lower()
+        return "llava" in model or "yi-vl" in model or "llava-next" in model
+
     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return "llava" in model_path or "yi-vl" in model_path
-
+        return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+
+    raise ValueError("unrecognized type")
+
+
+def decode_video_base64(video_base64):
+    from PIL import Image
+
+    # Decode the base64 string
+    video_bytes = base64.b64decode(video_base64)
+
+    # Placeholder for the start indices of each PNG image
+    img_starts = []
+
+    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    assert frame_format in [
+        "PNG",
+        "JPEG",
+    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
+
+    if frame_format == "PNG":
+        # Find each PNG start signature to isolate images
+        i = 0
+        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
+            # Check if we found the start of a PNG file
+            if (
+                video_bytes[i] == 0x89
+                and video_bytes[i + 1] == 0x50
+                and video_bytes[i + 2] == 0x4E
+                and video_bytes[i + 3] == 0x47
+                and video_bytes[i + 4] == 0x0D
+                and video_bytes[i + 5] == 0x0A
+                and video_bytes[i + 6] == 0x1A
+                and video_bytes[i + 7] == 0x0A
+            ):
+                img_starts.append(i)
+                i += 8  # Skip the PNG signature
+            else:
+                i += 1
+    else:
+        # Find each JPEG start (0xFFD8) to isolate images
+        i = 0
+        while (
+            i < len(video_bytes) - 1
+        ):  # Adjusted for the length of the JPEG SOI signature
+            # Check if we found the start of a JPEG file
+            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
+                img_starts.append(i)
+                # Move to the next byte to continue searching for the next image start
+                i += 2
+            else:
+                i += 1
+
+    frames = []
+    for start_idx in img_starts:
+        # Assuming each image is back-to-back, the end of one image is the start of another
+        # The last image goes until the end of the byte string
+        end_idx = (
+            img_starts[img_starts.index(start_idx) + 1]
+            if img_starts.index(start_idx) + 1 < len(img_starts)
+            else len(video_bytes)
+        )
+        img_bytes = video_bytes[start_idx:end_idx]
+
+        # Convert bytes to a PIL Image
+        img = Image.open(BytesIO(img_bytes))
+
+        # Convert PIL Image to a NumPy array
+        frame = np.array(img)
+
+        # Append the frame to the list of frames
+        frames.append(frame)
+
+    # Ensure there's at least one frame to avoid errors with np.stack
+    if frames:
+        return np.stack(frames, axis=0), img.size
+    else:
+        return np.array([]), (
+            0,
+            0,
+        )  # Return an empty array and size tuple if no frames were found
 
 
 def load_image(image_file):
     from PIL import Image
 
-    image = None
+    image = image_size = None
 
     if image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
```
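`decode_video_base64` treats a "video" as nothing more than base64 over back-to-back encoded frames, re-splitting the byte stream on PNG (or JPEG) magic numbers; the next hunk shows `load_image` routing `video:`-prefixed inputs to it. A round-trip sketch; the encoding side is inferred from the decoder, not code from this package:

```python
import base64
from io import BytesIO

from PIL import Image

from sglang.srt.utils import decode_video_base64

# Encode: concatenate standalone PNG files, then base64 the whole blob.
blob = b""
for color in ("red", "green", "blue"):
    buf = BytesIO()
    Image.new("RGB", (64, 48), color).save(buf, format="PNG")
    blob += buf.getvalue()
video_base64 = base64.b64encode(blob).decode()

# Decode: the PNG signatures are located again; each frame becomes an array.
frames, size = decode_video_base64(video_base64)
print(frames.shape, size)  # (3, 48, 64, 3) (64, 48)
```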
```diff
@@ -255,7 +373,54 @@ def load_image(image_file):
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(base64.b64decode(image_file)))
+    elif image_file.startswith("video:"):
+        image_file = image_file.replace("video:", "")
+        image, image_size = decode_video_base64(image_file)
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
-    return image
+    return image, image_size
+
+
+def assert_pkg_version(pkg: str, min_version: str):
+    try:
+        installed_version = version(pkg)
+        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
+            raise Exception(
+                f"{pkg} is installed with version {installed_version} which "
+                f"is less than the minimum required version {min_version}"
+            )
+    except PackageNotFoundError:
+        raise Exception(
+            f"{pkg} with minimum required version {min_version} is not installed"
+        )
+
+
+API_KEY_HEADER_NAME = "X-API-Key"
+
+
+class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, api_key: str):
+        super().__init__(app)
+        self.api_key = api_key
+
+    async def dispatch(self, request, call_next):
+        # extract API key from the request headers
+        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+        if not api_key_header or api_key_header != self.api_key:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "Invalid API Key"},
+            )
+        response = await call_next(request)
+        return response
+
+
+# FIXME: Remove this once we drop support for pydantic 1.x
+IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
+
+
+def jsonify_pydantic_model(obj: BaseModel):
+    if IS_PYDANTIC_1:
+        return obj.json(ensure_ascii=False)
+    return obj.model_dump_json()
```