sglang 0.3.6.post1__py3-none-any.whl → 0.3.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +4 -8
- sglang/bench_one_batch_server.py +6 -5
- sglang/check_env.py +7 -1
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +2 -4
- sglang/srt/configs/model_config.py +2 -6
- sglang/srt/layers/attention/flashinfer_backend.py +3 -3
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -6
- sglang/srt/managers/image_processor.py +7 -10
- sglang/srt/managers/io_struct.py +0 -10
- sglang/srt/managers/schedule_batch.py +51 -13
- sglang/srt/managers/scheduler.py +41 -29
- sglang/srt/managers/session_controller.py +15 -7
- sglang/srt/managers/tokenizer_manager.py +4 -33
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -2
- sglang/srt/models/grok.py +11 -48
- sglang/srt/models/llava.py +16 -9
- sglang/srt/models/olmo2.py +392 -0
- sglang/srt/models/qwen2_vl.py +10 -3
- sglang/srt/openai_api/adapter.py +1 -1
- sglang/srt/server.py +48 -45
- sglang/srt/server_args.py +1 -1
- sglang/srt/utils.py +22 -24
- sglang/test/test_utils.py +21 -8
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/METADATA +4 -2
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/RECORD +34 -36
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py
CHANGED
@@ -14,20 +14,20 @@ import argparse
 import dataclasses
 import json
 import logging
+import os
 import random
 import time
 from typing import Dict, List, Optional, Tuple

 import numpy as np

-from sglang.api import Engine
 from sglang.bench_serving import (
     get_dataset,
     get_tokenizer,
     sample_random_requests,
     set_ulimit,
 )
-from sglang.srt.server import Runtime
+from sglang.srt.server import Engine, Runtime
 from sglang.srt.server_args import ServerArgs


@@ -52,6 +52,7 @@ class BenchArgs:
     seed: int = 1
     skip_warmup: bool = False
     do_not_exit: bool = False
+    profile: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -156,6 +157,12 @@ class BenchArgs:
             action="store_true",
             help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
         )
+        parser.add_argument(
+            "--profile",
+            action="store_true",
+            help="Use Torch Profiler. The endpoint must be launched with "
+            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -169,6 +176,7 @@ def throughput_test_once(
     reqs: List[Tuple[str, int, int]],
     ignore_eos: bool,
     extra_request_body: Dict,
+    profile: bool,
 ):
     measurement_results = {
         "backend": backend_name,
@@ -194,7 +202,15 @@ def throughput_test_once(
     ]

     st = time.perf_counter()
+    if profile:
+        backend.start_profile()
+
     gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
+
+    if profile:
+        backend.stop_profile()
+        monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
+
     latency = time.perf_counter() - st

     if backend_name == "runtime":
@@ -221,6 +237,41 @@ def throughput_test_once(
     return measurement_results


+def monitor_trace_file(directory, interval=1):
+
+    print(f"Monitoring {directory} for new trace files...")
+
+    known_files = set(os.listdir(directory))
+
+    while True:
+        flag = False
+        time.sleep(interval)
+        current_files = set(os.listdir(directory))
+
+        new_files = current_files - known_files
+        for new_file in new_files:
+            new_file_path = os.path.join(directory, new_file)
+            print(f"New file detected: {new_file}")
+
+            previous_size = 0
+            while True:
+                try:
+                    current_size = os.path.getsize(new_file_path)
+                except FileNotFoundError:
+                    print(f"File {new_file} is no longer accessible.")
+                    break
+
+                if current_size > previous_size:
+                    previous_size = current_size
+                else:
+                    flag = True
+                    break
+
+                time.sleep(interval)
+        if flag:
+            break
+
+
 def throughput_test(
     server_args: ServerArgs,
     bench_args: BenchArgs,
@@ -268,6 +319,7 @@ def throughput_test(
         reqs=warmup_requests,
         ignore_eos=not bench_args.disable_ignore_eos,
         extra_request_body=extra_request_body,
+        profile=False,
     )

     logging.info("\nBenchmark...")
@@ -277,6 +329,7 @@ def throughput_test(
         reqs=input_requests,
         ignore_eos=not bench_args.disable_ignore_eos,
         extra_request_body=extra_request_body,
+        profile=bench_args.profile,
     )

     if bench_args.result_filename:
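Note: with --profile, throughput_test_once brackets the generate call with backend.start_profile()/backend.stop_profile() and then calls monitor_trace_file, which polls SGLANG_TORCH_PROFILER_DIR until a newly written trace stops growing. A standalone sketch of that size-stabilization idiom (illustrative only, not the sglang code itself):

    import os
    import time

    def wait_until_stable(path, interval=1):
        # Poll the file size until two consecutive reads match, i.e. the writer
        # has most likely finished flushing the trace to disk.
        previous_size = -1
        while True:
            try:
                current_size = os.path.getsize(path)
            except FileNotFoundError:
                return False  # the file disappeared before it stabilized
            if current_size == previous_size:
                return True
            previous_size = current_size
            time.sleep(interval)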
sglang/bench_one_batch.py
CHANGED
@@ -47,6 +47,7 @@ import itertools
 import json
 import logging
 import multiprocessing
+import os
 import time
 from typing import Tuple

@@ -62,11 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    configure_logger,
-    kill_child_process,
-    suppress_other_loggers,
-)
+from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers


 @dataclasses.dataclass
@@ -466,7 +463,6 @@ if __name__ == "__main__":

     try:
         main(server_args, bench_args)
-    except Exception as e:
-        raise e
     finally:
-        kill_child_process()
+        if server_args.tp_size != 1:
+            kill_process_tree(os.getpid(), include_parent=False)
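Note: kill_child_process was reworked into kill_process_tree across these entry points, with the pid and an include_parent flag now passed explicitly. A rough psutil-based sketch of what such a helper can look like (an illustrative assumption; sglang's actual implementation may differ):

    import psutil

    def kill_process_tree(parent_pid, include_parent=True):
        # Terminate every descendant of parent_pid, optionally the parent too.
        try:
            parent = psutil.Process(parent_pid)
        except psutil.NoSuchProcess:
            return
        for child in parent.children(recursive=True):
            try:
                child.kill()
            except psutil.NoSuchProcess:
                pass
        if include_parent:
            try:
                parent.kill()
            except psutil.NoSuchProcess:
                pass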
sglang/bench_one_batch_server.py
CHANGED
@@ -5,9 +5,9 @@ This script launches a server and uses the HTTP interface.
 It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).

 Usage:
-python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8

-python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
+python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 """

 import argparse
@@ -15,6 +15,7 @@ import dataclasses
 import itertools
 import json
 import multiprocessing
+import os
 import time
 from typing import Tuple

@@ -23,7 +24,7 @@ import requests

 from sglang.srt.server import launch_server
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import kill_child_process
+from sglang.srt.utils import kill_process_tree


 @dataclasses.dataclass
@@ -69,7 +70,7 @@ def launch_server_internal(server_args):
     except Exception as e:
         raise e
     finally:
-        kill_child_process()
+        kill_process_tree(os.getpid(), include_parent=False)


 def launch_server_process(server_args: ServerArgs):
@@ -175,7 +176,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         )
     finally:
         if proc:
-            kill_child_process(proc.pid, include_self=True)
+            kill_process_tree(proc.pid)

     print(f"\nResults are saved to {bench_args.result_filename}")
sglang/check_env.py
CHANGED
@@ -22,18 +22,24 @@ PACKAGE_LIST = [
     "hf_transfer",
     "huggingface_hub",
     "interegular",
+    "modelscope",
+    "orjson",
+    "outlines",
+    "packaging",
     "psutil",
     "pydantic",
     "multipart",
     "zmq",
+    "torchao",
     "uvicorn",
     "uvloop",
     "vllm",
-    "outlines",
+    "xgrammar",
     "openai",
     "tiktoken",
     "anthropic",
     "litellm",
+    "decord",
 ]

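Note: the list above is what check_env.py probes when it reports the local environment. A minimal, hypothetical way to produce a similar version report with only the standard library (the package names are just examples taken from the list):

    from importlib.metadata import PackageNotFoundError, version

    for pkg in ("torchao", "xgrammar", "decord"):
        try:
            print(f"{pkg}: {version(pkg)}")
        except PackageNotFoundError:
            print(f"{pkg}: not installed")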
sglang/lang/tracer.py
CHANGED
sglang/launch_server.py
CHANGED
@@ -5,14 +5,12 @@ import sys

 from sglang.srt.server import launch_server
 from sglang.srt.server_args import prepare_server_args
-from sglang.srt.utils import kill_child_process
+from sglang.srt.utils import kill_process_tree

 if __name__ == "__main__":
     server_args = prepare_server_args(sys.argv[1:])

     try:
         launch_server(server_args)
-    except Exception as e:
-        raise e
     finally:
-        kill_child_process()
+        kill_process_tree(os.getpid(), include_parent=False)
sglang/srt/configs/model_config.py
CHANGED
@@ -14,13 +14,13 @@

 import json
 import logging
-import os
 from enum import IntEnum, auto
 from typing import List, Optional

 from transformers import PretrainedConfig

 from sglang.srt.hf_transformers_utils import get_config, get_context_length
+from sglang.srt.utils import get_bool_env_var

 logger = logging.getLogger(__name__)

@@ -59,13 +59,9 @@ class ModelConfig:

         # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
-        allow_long_context = os.environ.get(
-            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
-        )
-
         if context_length is not None:
             if context_length > derived_context_len:
-                if allow_long_context:
+                if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
                     logger.warning(
                         f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                         f"This may lead to incorrect model outputs or CUDA errors."
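Note: this file and flashinfer_backend.py below replace ad-hoc os.environ parsing with a shared get_bool_env_var helper. A plausible shape for such a helper, shown only as an assumption about its semantics rather than sglang's exact code:

    import os

    def get_bool_env_var(name: str, default: str = "false") -> bool:
        # Treat "true"/"1" (case-insensitive) as enabled; anything else is disabled.
        return os.getenv(name, default).lower() in ("true", "1")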
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -18,7 +18,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import get_bool_env_var, is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -47,8 +47,8 @@ class FlashInferAttnBackend(AttentionBackend):

         # Parse constants
         if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
-            self.decode_use_tensor_cores = (
-                os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"].lower() == "true"
+            self.decode_use_tensor_cores = get_bool_env_var(
+                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
             )
         else:
             if not _grouped_size_compiled_for_decode_kernels(
sglang/srt/layers/sampler.py
CHANGED
@@ -74,7 +74,7 @@ class Sampler(nn.Module):
                 filter_apply_order="joint",
             )

-            if not torch.all(success):
+            if self.use_nan_detectioin and not torch.all(success):
                 logger.warning("Detected errors during sampling!")
                 batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         elif global_server_args_dict["sampling_backend"] == "pytorch":
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -15,9 +15,11 @@

 import logging
 import multiprocessing as mp
+import signal
 import threading
 from enum import Enum, auto

+import psutil
 import zmq

 from sglang.srt.managers.io_struct import (
@@ -26,13 +28,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    bind_port,
-    configure_logger,
-    get_zmq_socket,
-    kill_parent_process,
-    suppress_other_loggers,
-)
+from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -235,7 +231,7 @@ def run_data_parallel_controller_process(
     pipe_writer,
 ):
     configure_logger(server_args)
-    suppress_other_loggers()
+    parent_process = psutil.Process().parent()

     try:
         controller = DataParallelController(server_args, port_args)
@@ -244,6 +240,6 @@ def run_data_parallel_controller_process(
         )
         controller.event_loop()
     except Exception:
-        msg = get_exception_traceback()
-        logger.error(msg)
-        kill_parent_process()
+        traceback = get_exception_traceback()
+        logger.error(f"DataParallelController hit an exception: {traceback}")
+        parent_process.send_signal(signal.SIGQUIT)
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -15,9 +15,11 @@

 import dataclasses
 import logging
+import signal
 from collections import OrderedDict
 from typing import List, Union

+import psutil
 import zmq

 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -25,12 +27,10 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
-    GetMemPoolSizeReqOutput,
-    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process
+from sglang.srt.utils import configure_logger, get_zmq_socket
 from sglang.utils import find_printable_text, get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -195,11 +195,12 @@ def run_detokenizer_process(
     port_args: PortArgs,
 ):
     configure_logger(server_args)
+    parent_process = psutil.Process().parent()

     try:
         manager = DetokenizerManager(server_args, port_args)
         manager.event_loop()
     except Exception:
-        msg = get_exception_traceback()
-        logger.error(msg)
-        kill_parent_process()
+        traceback = get_exception_traceback()
+        logger.error(f"DetokenizerManager hit an exception: {traceback}")
+        parent_process.send_signal(signal.SIGQUIT)
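Note: run_data_parallel_controller_process and run_detokenizer_process now share the same crash-reporting pattern: resolve the parent process once via psutil at startup, then notify it with SIGQUIT on a fatal error instead of killing it directly. The shared pattern as a standalone sketch (the run_worker wrapper is hypothetical):

    import signal

    import psutil

    def run_worker(event_loop):
        parent_process = psutil.Process().parent()
        try:
            event_loop()
        except Exception:
            # Leave the shutdown decision to the parent process.
            parent_process.send_signal(signal.SIGQUIT)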
sglang/srt/managers/image_processor.py
CHANGED
@@ -131,6 +131,7 @@ class LlavaImageProcessor(BaseImageProcessor):
         if not image_data:
             return None

+        modalities = request_obj.modalities or ["image"]
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
             self.hf_config.image_grid_pinpoints
@@ -139,9 +140,12 @@ class LlavaImageProcessor(BaseImageProcessor):
             else None
         )

+        if isinstance(image_data, str):
+            image_data = [image_data]
+
         if isinstance(image_data, list) and len(image_data) > 0:
-            # Multiple images
-            if len(image_data) > 1:
+            if "multi-images" in modalities or "video" in modalities:
+                # Multiple images
                 aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
                 pixel_values, image_hashes, image_sizes = [], [], []
                 res = []
@@ -166,13 +170,6 @@ class LlavaImageProcessor(BaseImageProcessor):
                 )
                 image_hashes = [image_hash]
                 image_sizes = [image_size]
-        elif isinstance(image_data, str):
-            # A single image
-            pixel_values, image_hash, image_size = await self._process_single_image(
-                image_data, aspect_ratio, grid_pinpoints
-            )
-            image_hashes = [image_hash]
-            image_sizes = [image_size]
         else:
             raise ValueError(f"Invalid image data: {image_data}")

@@ -341,7 +338,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             "pixel_values": pixel_values,
             "image_hashes": image_hashes,
             "image_sizes": image_sizes,
-            "modalities": request_obj.modalities,
+            "modalities": request_obj.modalities or ["image"],
             "image_grid_thws": image_grid_thws,
         }
sglang/srt/managers/io_struct.py
CHANGED
@@ -376,16 +376,6 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2


-@dataclass
-class GetMemPoolSizeReq:
-    pass
-
-
-@dataclass
-class GetMemPoolSizeReqOutput:
-    size: int
-
-
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -31,6 +31,7 @@ import dataclasses
 import logging
 from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -123,7 +124,7 @@ class FINISH_ABORT(BaseFinishReason):
 class ImageInputs:
     """The image related inputs."""

-    pixel_values: torch.Tensor
+    pixel_values: Union[torch.Tensor, np.array]
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
@@ -131,7 +132,7 @@ class ImageInputs:
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None

-    image_embeds: Optional[List[torch.Tensor]] = None
+    # Llava related
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None

@@ -140,19 +141,17 @@ class ImageInputs:
     mrope_position_delta: Optional[torch.Tensor] = None

     @staticmethod
-    def from_dict(obj, vocab_size):
-        # Use image hash as fake token_ids, which is then used for prefix matching
+    def from_dict(obj: dict):
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
-            image_hashes=hash(tuple(obj["image_hashes"])),
+            image_hashes=obj["image_hashes"],
         )
-        image_hash = ret.image_hashes
-        ret.pad_values = [
-            (image_hash) % vocab_size,
-            (image_hash >> 16) % vocab_size,
-            (image_hash >> 32) % vocab_size,
-            (image_hash >> 64) % vocab_size,
-        ]
+
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
+        # errors in cuda kernels. See also llava.py for example.
+        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

         optional_args = [
             "image_sizes",
@@ -167,6 +166,29 @@ class ImageInputs:

         return ret

+    def merge(self, other):
+        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
+        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
+
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
+        # errors in cuda kernels. See also llava.py for example.
+        self.image_hashes += other.image_hashes
+        self.pad_values = [x % (1 << 30) for x in self.image_hashes]
+
+        optional_args = [
+            "image_sizes",
+            "image_offsets",
+            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
+            "aspect_ratio_ids",
+            "aspect_ratio_mask",
+            "image_grid_thws",
+        ]
+        for arg in optional_args:
+            if getattr(self, arg, None) is not None:
+                setattr(self, arg, getattr(self, arg) + getattr(other, arg))
+

 class Req:
     """The input and output status of a request."""
@@ -177,6 +199,7 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
@@ -184,7 +207,11 @@ class Req:
         # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
-        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
+        self.origin_input_ids_unpadded = (
+            origin_input_ids_unpadded
+            if origin_input_ids_unpadded
+            else origin_input_ids  # Before image padding
+        )
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
@@ -201,6 +228,7 @@ class Req:
         self.tokenizer = None
         self.finished_reason = None
         self.stream = False
+        self.to_abort = False

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -260,6 +288,12 @@ class Req:
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0

+    def extend_image_inputs(self, image_inputs):
+        if self.image_inputs is None:
+            self.image_inputs = image_inputs
+        else:
+            self.image_inputs.merge(image_inputs)
+
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -332,6 +366,10 @@ class Req:
         if self.finished():
             return

+        if self.to_abort:
+            self.finished_reason = FINISH_ABORT()
+            return
+
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
             self.finished_reason = FINISH_LENGTH(
                 length=self.sampling_params.max_new_tokens