sglang 0.3.6.post2__py3-none-any.whl → 0.3.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +3 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/launch_server.py +3 -2
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +0 -10
- sglang/srt/managers/schedule_batch.py +24 -22
- sglang/srt/managers/scheduler.py +35 -26
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +4 -33
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -2
- sglang/srt/models/grok.py +11 -48
- sglang/srt/models/llava.py +9 -8
- sglang/srt/models/olmo2.py +392 -0
- sglang/srt/models/qwen2_vl.py +10 -3
- sglang/srt/openai_api/adapter.py +1 -1
- sglang/srt/server.py +46 -44
- sglang/srt/server_args.py +1 -1
- sglang/srt/utils.py +8 -20
- sglang/test/test_utils.py +20 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.3.6.post3.dist-info}/METADATA +2 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.3.6.post3.dist-info}/RECORD +29 -31
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- {sglang-0.3.6.post2.dist-info → sglang-0.3.6.post3.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.3.6.post3.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.3.6.post3.dist-info}/top_level.txt +0 -0
@@ -14,20 +14,20 @@ import argparse
|
|
14
14
|
import dataclasses
|
15
15
|
import json
|
16
16
|
import logging
|
17
|
+
import os
|
17
18
|
import random
|
18
19
|
import time
|
19
20
|
from typing import Dict, List, Optional, Tuple
|
20
21
|
|
21
22
|
import numpy as np
|
22
23
|
|
23
|
-
from sglang.api import Engine
|
24
24
|
from sglang.bench_serving import (
|
25
25
|
get_dataset,
|
26
26
|
get_tokenizer,
|
27
27
|
sample_random_requests,
|
28
28
|
set_ulimit,
|
29
29
|
)
|
30
|
-
from sglang.srt.server import Runtime
|
30
|
+
from sglang.srt.server import Engine, Runtime
|
31
31
|
from sglang.srt.server_args import ServerArgs
|
32
32
|
|
33
33
|
|
@@ -52,6 +52,7 @@ class BenchArgs:
|
|
52
52
|
seed: int = 1
|
53
53
|
skip_warmup: bool = False
|
54
54
|
do_not_exit: bool = False
|
55
|
+
profile: bool = False
|
55
56
|
|
56
57
|
@staticmethod
|
57
58
|
def add_cli_args(parser: argparse.ArgumentParser):
|
@@ -156,6 +157,12 @@ class BenchArgs:
|
|
156
157
|
action="store_true",
|
157
158
|
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
|
158
159
|
)
|
160
|
+
parser.add_argument(
|
161
|
+
"--profile",
|
162
|
+
action="store_true",
|
163
|
+
help="Use Torch Profiler. The endpoint must be launched with "
|
164
|
+
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
165
|
+
)
|
159
166
|
|
160
167
|
@classmethod
|
161
168
|
def from_cli_args(cls, args: argparse.Namespace):
|
@@ -169,6 +176,7 @@ def throughput_test_once(
|
|
169
176
|
reqs: List[Tuple[str, int, int]],
|
170
177
|
ignore_eos: bool,
|
171
178
|
extra_request_body: Dict,
|
179
|
+
profile: bool,
|
172
180
|
):
|
173
181
|
measurement_results = {
|
174
182
|
"backend": backend_name,
|
@@ -194,7 +202,15 @@ def throughput_test_once(
|
|
194
202
|
]
|
195
203
|
|
196
204
|
st = time.perf_counter()
|
205
|
+
if profile:
|
206
|
+
backend.start_profile()
|
207
|
+
|
197
208
|
gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
|
209
|
+
|
210
|
+
if profile:
|
211
|
+
backend.stop_profile()
|
212
|
+
monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
|
213
|
+
|
198
214
|
latency = time.perf_counter() - st
|
199
215
|
|
200
216
|
if backend_name == "runtime":
|
@@ -221,6 +237,41 @@ def throughput_test_once(
|
|
221
237
|
return measurement_results
|
222
238
|
|
223
239
|
|
240
|
+
def monitor_trace_file(directory, interval=1):
|
241
|
+
|
242
|
+
print(f"Monitoring {directory} for new trace files...")
|
243
|
+
|
244
|
+
known_files = set(os.listdir(directory))
|
245
|
+
|
246
|
+
while True:
|
247
|
+
flag = False
|
248
|
+
time.sleep(interval)
|
249
|
+
current_files = set(os.listdir(directory))
|
250
|
+
|
251
|
+
new_files = current_files - known_files
|
252
|
+
for new_file in new_files:
|
253
|
+
new_file_path = os.path.join(directory, new_file)
|
254
|
+
print(f"New file detected: {new_file}")
|
255
|
+
|
256
|
+
previous_size = 0
|
257
|
+
while True:
|
258
|
+
try:
|
259
|
+
current_size = os.path.getsize(new_file_path)
|
260
|
+
except FileNotFoundError:
|
261
|
+
print(f"File {new_file} is no longer accessible.")
|
262
|
+
break
|
263
|
+
|
264
|
+
if current_size > previous_size:
|
265
|
+
previous_size = current_size
|
266
|
+
else:
|
267
|
+
flag = True
|
268
|
+
break
|
269
|
+
|
270
|
+
time.sleep(interval)
|
271
|
+
if flag:
|
272
|
+
break
|
273
|
+
|
274
|
+
|
224
275
|
def throughput_test(
|
225
276
|
server_args: ServerArgs,
|
226
277
|
bench_args: BenchArgs,
|
@@ -268,6 +319,7 @@ def throughput_test(
|
|
268
319
|
reqs=warmup_requests,
|
269
320
|
ignore_eos=not bench_args.disable_ignore_eos,
|
270
321
|
extra_request_body=extra_request_body,
|
322
|
+
profile=False,
|
271
323
|
)
|
272
324
|
|
273
325
|
logging.info("\nBenchmark...")
|
@@ -277,6 +329,7 @@ def throughput_test(
|
|
277
329
|
reqs=input_requests,
|
278
330
|
ignore_eos=not bench_args.disable_ignore_eos,
|
279
331
|
extra_request_body=extra_request_body,
|
332
|
+
profile=bench_args.profile,
|
280
333
|
)
|
281
334
|
|
282
335
|
if bench_args.result_filename:
|
sglang/bench_one_batch.py
CHANGED
@@ -47,6 +47,7 @@ import itertools
|
|
47
47
|
import json
|
48
48
|
import logging
|
49
49
|
import multiprocessing
|
50
|
+
import os
|
50
51
|
import time
|
51
52
|
from typing import Tuple
|
52
53
|
|
@@ -62,11 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
|
|
62
63
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
63
64
|
from sglang.srt.server import _set_envs_and_config
|
64
65
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
65
|
-
from sglang.srt.utils import
|
66
|
-
configure_logger,
|
67
|
-
kill_child_process,
|
68
|
-
suppress_other_loggers,
|
69
|
-
)
|
66
|
+
from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
|
70
67
|
|
71
68
|
|
72
69
|
@dataclasses.dataclass
|
@@ -468,4 +465,4 @@ if __name__ == "__main__":
|
|
468
465
|
main(server_args, bench_args)
|
469
466
|
finally:
|
470
467
|
if server_args.tp_size != 1:
|
471
|
-
|
468
|
+
kill_process_tree(os.getpid(), include_parent=False)
|
sglang/bench_one_batch_server.py
CHANGED
@@ -15,6 +15,7 @@ import dataclasses
|
|
15
15
|
import itertools
|
16
16
|
import json
|
17
17
|
import multiprocessing
|
18
|
+
import os
|
18
19
|
import time
|
19
20
|
from typing import Tuple
|
20
21
|
|
@@ -23,7 +24,7 @@ import requests
|
|
23
24
|
|
24
25
|
from sglang.srt.server import launch_server
|
25
26
|
from sglang.srt.server_args import ServerArgs
|
26
|
-
from sglang.srt.utils import
|
27
|
+
from sglang.srt.utils import kill_process_tree
|
27
28
|
|
28
29
|
|
29
30
|
@dataclasses.dataclass
|
@@ -69,7 +70,7 @@ def launch_server_internal(server_args):
|
|
69
70
|
except Exception as e:
|
70
71
|
raise e
|
71
72
|
finally:
|
72
|
-
|
73
|
+
kill_process_tree(os.getpid(), include_parent=False)
|
73
74
|
|
74
75
|
|
75
76
|
def launch_server_process(server_args: ServerArgs):
|
@@ -175,7 +176,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
|
175
176
|
)
|
176
177
|
finally:
|
177
178
|
if proc:
|
178
|
-
|
179
|
+
kill_process_tree(proc.pid)
|
179
180
|
|
180
181
|
print(f"\nResults are saved to {bench_args.result_filename}")
|
181
182
|
|
sglang/launch_server.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1
1
|
"""Launch the inference server."""
|
2
2
|
|
3
|
+
import os
|
3
4
|
import sys
|
4
5
|
|
5
6
|
from sglang.srt.server import launch_server
|
6
7
|
from sglang.srt.server_args import prepare_server_args
|
7
|
-
from sglang.srt.utils import
|
8
|
+
from sglang.srt.utils import kill_process_tree
|
8
9
|
|
9
10
|
if __name__ == "__main__":
|
10
11
|
server_args = prepare_server_args(sys.argv[1:])
|
@@ -12,4 +13,4 @@ if __name__ == "__main__":
|
|
12
13
|
try:
|
13
14
|
launch_server(server_args)
|
14
15
|
finally:
|
15
|
-
|
16
|
+
kill_process_tree(os.getpid(), include_parent=False)
|
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
import logging
|
17
17
|
import multiprocessing as mp
|
18
|
+
import signal
|
18
19
|
import threading
|
19
20
|
from enum import Enum, auto
|
20
21
|
|
22
|
+
import psutil
|
21
23
|
import zmq
|
22
24
|
|
23
25
|
from sglang.srt.managers.io_struct import (
|
@@ -26,13 +28,7 @@ from sglang.srt.managers.io_struct import (
|
|
26
28
|
)
|
27
29
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
28
30
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
29
|
-
from sglang.srt.utils import
|
30
|
-
bind_port,
|
31
|
-
configure_logger,
|
32
|
-
get_zmq_socket,
|
33
|
-
kill_parent_process,
|
34
|
-
suppress_other_loggers,
|
35
|
-
)
|
31
|
+
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
|
36
32
|
from sglang.utils import get_exception_traceback
|
37
33
|
|
38
34
|
logger = logging.getLogger(__name__)
|
@@ -235,7 +231,7 @@ def run_data_parallel_controller_process(
|
|
235
231
|
pipe_writer,
|
236
232
|
):
|
237
233
|
configure_logger(server_args)
|
238
|
-
|
234
|
+
parent_process = psutil.Process().parent()
|
239
235
|
|
240
236
|
try:
|
241
237
|
controller = DataParallelController(server_args, port_args)
|
@@ -244,6 +240,6 @@ def run_data_parallel_controller_process(
|
|
244
240
|
)
|
245
241
|
controller.event_loop()
|
246
242
|
except Exception:
|
247
|
-
|
248
|
-
logger.error(
|
249
|
-
|
243
|
+
traceback = get_exception_traceback()
|
244
|
+
logger.error(f"DataParallelController hit an exception: {traceback}")
|
245
|
+
parent_process.send_signal(signal.SIGQUIT)
|
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
import dataclasses
|
17
17
|
import logging
|
18
|
+
import signal
|
18
19
|
from collections import OrderedDict
|
19
20
|
from typing import List, Union
|
20
21
|
|
22
|
+
import psutil
|
21
23
|
import zmq
|
22
24
|
|
23
25
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
@@ -28,7 +30,7 @@ from sglang.srt.managers.io_struct import (
|
|
28
30
|
)
|
29
31
|
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
|
30
32
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
31
|
-
from sglang.srt.utils import configure_logger, get_zmq_socket
|
33
|
+
from sglang.srt.utils import configure_logger, get_zmq_socket
|
32
34
|
from sglang.utils import find_printable_text, get_exception_traceback
|
33
35
|
|
34
36
|
logger = logging.getLogger(__name__)
|
@@ -193,11 +195,12 @@ def run_detokenizer_process(
|
|
193
195
|
port_args: PortArgs,
|
194
196
|
):
|
195
197
|
configure_logger(server_args)
|
198
|
+
parent_process = psutil.Process().parent()
|
196
199
|
|
197
200
|
try:
|
198
201
|
manager = DetokenizerManager(server_args, port_args)
|
199
202
|
manager.event_loop()
|
200
203
|
except Exception:
|
201
|
-
|
202
|
-
logger.error(
|
203
|
-
|
204
|
+
traceback = get_exception_traceback()
|
205
|
+
logger.error(f"DetokenizerManager hit an exception: {traceback}")
|
206
|
+
parent_process.send_signal(signal.SIGQUIT)
|
@@ -338,7 +338,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
|
338
338
|
"pixel_values": pixel_values,
|
339
339
|
"image_hashes": image_hashes,
|
340
340
|
"image_sizes": image_sizes,
|
341
|
-
"modalities": request_obj.modalities,
|
341
|
+
"modalities": request_obj.modalities or ["image"],
|
342
342
|
"image_grid_thws": image_grid_thws,
|
343
343
|
}
|
344
344
|
|
sglang/srt/managers/io_struct.py
CHANGED
@@ -376,16 +376,6 @@ class ProfileReq(Enum):
|
|
376
376
|
STOP_PROFILE = 2
|
377
377
|
|
378
378
|
|
379
|
-
@dataclass
|
380
|
-
class GetMemPoolSizeReq:
|
381
|
-
pass
|
382
|
-
|
383
|
-
|
384
|
-
@dataclass
|
385
|
-
class GetMemPoolSizeReqOutput:
|
386
|
-
size: int
|
387
|
-
|
388
|
-
|
389
379
|
@dataclass
|
390
380
|
class OpenSessionReqInput:
|
391
381
|
capacity_of_str_len: int
|
@@ -124,7 +124,7 @@ class FINISH_ABORT(BaseFinishReason):
|
|
124
124
|
class ImageInputs:
|
125
125
|
"""The image related inputs."""
|
126
126
|
|
127
|
-
pixel_values: torch.Tensor
|
127
|
+
pixel_values: Union[torch.Tensor, np.array]
|
128
128
|
image_hashes: Optional[list] = None
|
129
129
|
image_sizes: Optional[list] = None
|
130
130
|
image_offsets: Optional[list] = None
|
@@ -132,7 +132,7 @@ class ImageInputs:
|
|
132
132
|
modalities: Optional[list] = None
|
133
133
|
num_image_tokens: Optional[int] = None
|
134
134
|
|
135
|
-
|
135
|
+
# Llava related
|
136
136
|
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
137
137
|
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
138
138
|
|
@@ -141,19 +141,17 @@ class ImageInputs:
|
|
141
141
|
mrope_position_delta: Optional[torch.Tensor] = None
|
142
142
|
|
143
143
|
@staticmethod
|
144
|
-
def from_dict(obj
|
145
|
-
# Use image hash as fake token_ids, which is then used for prefix matching
|
144
|
+
def from_dict(obj: dict):
|
146
145
|
ret = ImageInputs(
|
147
146
|
pixel_values=obj["pixel_values"],
|
148
|
-
image_hashes=
|
147
|
+
image_hashes=obj["image_hashes"],
|
149
148
|
)
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
]
|
149
|
+
|
150
|
+
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
151
|
+
# Please note that if the `input_ids` is later used in the model forward,
|
152
|
+
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
153
|
+
# errors in cuda kernels. See also llava.py for example.
|
154
|
+
ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
|
157
155
|
|
158
156
|
optional_args = [
|
159
157
|
"image_sizes",
|
@@ -168,17 +166,16 @@ class ImageInputs:
|
|
168
166
|
|
169
167
|
return ret
|
170
168
|
|
171
|
-
def merge(self, other
|
169
|
+
def merge(self, other):
|
172
170
|
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
173
171
|
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
174
|
-
self.image_hashes += other.image_hashes
|
175
172
|
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
]
|
173
|
+
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
174
|
+
# Please note that if the `input_ids` is later used in the model forward,
|
175
|
+
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
176
|
+
# errors in cuda kernels. See also llava.py for example.
|
177
|
+
self.image_hashes += other.image_hashes
|
178
|
+
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
|
182
179
|
|
183
180
|
optional_args = [
|
184
181
|
"image_sizes",
|
@@ -231,6 +228,7 @@ class Req:
|
|
231
228
|
self.tokenizer = None
|
232
229
|
self.finished_reason = None
|
233
230
|
self.stream = False
|
231
|
+
self.to_abort = False
|
234
232
|
|
235
233
|
# For incremental decoding
|
236
234
|
# ----- | --------- read_ids -------|
|
@@ -290,11 +288,11 @@ class Req:
|
|
290
288
|
# The number of cached tokens, that were already cached in the KV cache
|
291
289
|
self.cached_tokens = 0
|
292
290
|
|
293
|
-
def extend_image_inputs(self, image_inputs
|
291
|
+
def extend_image_inputs(self, image_inputs):
|
294
292
|
if self.image_inputs is None:
|
295
293
|
self.image_inputs = image_inputs
|
296
294
|
else:
|
297
|
-
self.image_inputs.merge(image_inputs
|
295
|
+
self.image_inputs.merge(image_inputs)
|
298
296
|
|
299
297
|
# whether request reached finished condition
|
300
298
|
def finished(self) -> bool:
|
@@ -368,6 +366,10 @@ class Req:
|
|
368
366
|
if self.finished():
|
369
367
|
return
|
370
368
|
|
369
|
+
if self.to_abort:
|
370
|
+
self.finished_reason = FINISH_ABORT()
|
371
|
+
return
|
372
|
+
|
371
373
|
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
372
374
|
self.finished_reason = FINISH_LENGTH(
|
373
375
|
length=self.sampling_params.max_new_tokens
|
sglang/srt/managers/scheduler.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
import logging
|
17
17
|
import os
|
18
|
+
import signal
|
18
19
|
import threading
|
19
20
|
import time
|
20
21
|
import warnings
|
@@ -23,6 +24,7 @@ from concurrent import futures
|
|
23
24
|
from types import SimpleNamespace
|
24
25
|
from typing import List, Optional
|
25
26
|
|
27
|
+
import psutil
|
26
28
|
import torch
|
27
29
|
import zmq
|
28
30
|
|
@@ -36,8 +38,6 @@ from sglang.srt.managers.io_struct import (
|
|
36
38
|
BatchTokenIDOut,
|
37
39
|
CloseSessionReqInput,
|
38
40
|
FlushCacheReq,
|
39
|
-
GetMemPoolSizeReq,
|
40
|
-
GetMemPoolSizeReqOutput,
|
41
41
|
OpenSessionReqInput,
|
42
42
|
OpenSessionReqOutput,
|
43
43
|
ProfileReq,
|
@@ -73,7 +73,6 @@ from sglang.srt.utils import (
|
|
73
73
|
crash_on_warnings,
|
74
74
|
get_bool_env_var,
|
75
75
|
get_zmq_socket,
|
76
|
-
kill_parent_process,
|
77
76
|
set_gpu_proc_affinity,
|
78
77
|
set_random_seed,
|
79
78
|
suppress_other_loggers,
|
@@ -170,6 +169,10 @@ class Scheduler:
|
|
170
169
|
self.enable_overlap = False
|
171
170
|
logger.info("Overlap scheduler is disabled for embedding models.")
|
172
171
|
|
172
|
+
if self.model_config.is_multimodal:
|
173
|
+
self.enable_overlap = False
|
174
|
+
logger.info("Overlap scheduler is disabled for multimodal models.")
|
175
|
+
|
173
176
|
if self.enable_overlap:
|
174
177
|
self.disable_jump_forward = True
|
175
178
|
|
@@ -312,6 +315,7 @@ class Scheduler:
|
|
312
315
|
self.watchdog_timeout = server_args.watchdog_timeout
|
313
316
|
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
314
317
|
t.start()
|
318
|
+
self.parent_process = psutil.Process().parent()
|
315
319
|
|
316
320
|
# Init profiler
|
317
321
|
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
|
@@ -355,7 +359,7 @@ class Scheduler:
|
|
355
359
|
self.watchdog_last_time = time.time()
|
356
360
|
time.sleep(self.watchdog_timeout / 2)
|
357
361
|
|
358
|
-
|
362
|
+
self.parent_process.send_signal(signal.SIGQUIT)
|
359
363
|
|
360
364
|
@torch.no_grad()
|
361
365
|
def event_loop_normal(self):
|
@@ -515,10 +519,6 @@ class Scheduler:
|
|
515
519
|
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
|
516
520
|
elif isinstance(recv_req, CloseSessionReqInput):
|
517
521
|
self.close_session(recv_req)
|
518
|
-
elif isinstance(recv_req, GetMemPoolSizeReq):
|
519
|
-
self.send_to_tokenizer.send_pyobj(
|
520
|
-
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
|
521
|
-
)
|
522
522
|
else:
|
523
523
|
raise ValueError(f"Invalid request: {recv_req}")
|
524
524
|
|
@@ -526,8 +526,9 @@ class Scheduler:
|
|
526
526
|
self,
|
527
527
|
recv_req: TokenizedGenerateReqInput,
|
528
528
|
):
|
529
|
+
# Create a new request
|
529
530
|
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
|
530
|
-
|
531
|
+
|
531
532
|
if recv_req.input_embeds is not None:
|
532
533
|
# Generate fake input_ids based on the length of input_embeds
|
533
534
|
seq_length = len(recv_req.input_embeds)
|
@@ -558,25 +559,30 @@ class Scheduler:
|
|
558
559
|
self.waiting_queue.append(req)
|
559
560
|
return
|
560
561
|
|
561
|
-
#
|
562
|
+
# Handle image inputs
|
562
563
|
if recv_req.image_inputs is not None:
|
563
|
-
image_inputs = ImageInputs.from_dict(
|
564
|
-
|
565
|
-
)
|
564
|
+
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
|
565
|
+
# Expand a single image token into multiple dummy tokens for receiving image embeddings
|
566
566
|
req.origin_input_ids = self.pad_input_ids_func(
|
567
567
|
req.origin_input_ids, image_inputs
|
568
568
|
)
|
569
|
-
req.extend_image_inputs(image_inputs
|
569
|
+
req.extend_image_inputs(image_inputs)
|
570
570
|
|
571
|
-
if len(req.origin_input_ids)
|
572
|
-
|
573
|
-
"
|
574
|
-
"
|
571
|
+
if len(req.origin_input_ids) >= self.max_req_input_len:
|
572
|
+
logger.error(
|
573
|
+
"Multimodal prompt is too long after expanding multimodal tokens. "
|
574
|
+
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
|
575
575
|
)
|
576
|
+
req.origin_input_ids = [0]
|
577
|
+
req.image_inputs = None
|
576
578
|
req.sampling_params.max_new_tokens = 0
|
579
|
+
req.finished_reason = FINISH_ABORT(
|
580
|
+
"Multimodal prompt is too long. Check server logs for details."
|
581
|
+
)
|
577
582
|
self.waiting_queue.append(req)
|
578
583
|
return
|
579
584
|
|
585
|
+
# Copy more attributes
|
580
586
|
req.return_logprob = recv_req.return_logprob
|
581
587
|
req.top_logprobs_num = recv_req.top_logprobs_num
|
582
588
|
req.stream = recv_req.stream
|
@@ -1344,13 +1350,15 @@ class Scheduler:
|
|
1344
1350
|
|
1345
1351
|
if to_del is not None:
|
1346
1352
|
del self.waiting_queue[to_del]
|
1353
|
+
logger.debug(f"Abort queued request. {req.rid=}")
|
1354
|
+
return
|
1347
1355
|
|
1348
1356
|
# Delete requests in the running batch
|
1349
1357
|
if self.running_batch:
|
1350
1358
|
for req in self.running_batch.reqs:
|
1351
1359
|
if req.rid == recv_req.rid and not req.finished():
|
1352
|
-
req.
|
1353
|
-
|
1360
|
+
logger.debug(f"Abort running request. {req.rid=}")
|
1361
|
+
req.to_abort = True
|
1354
1362
|
break
|
1355
1363
|
|
1356
1364
|
def update_weights(self, recv_req: UpdateWeightReqInput):
|
@@ -1409,9 +1417,9 @@ def run_scheduler_process(
|
|
1409
1417
|
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
1410
1418
|
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
1411
1419
|
|
1412
|
-
# [For Router] if env var "
|
1413
|
-
if dp_rank is None and "
|
1414
|
-
dp_rank = int(os.environ["
|
1420
|
+
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
|
1421
|
+
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
|
1422
|
+
dp_rank = int(os.environ["SGLANG_DP_RANK"])
|
1415
1423
|
|
1416
1424
|
if dp_rank is None:
|
1417
1425
|
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
@@ -1419,6 +1427,7 @@ def run_scheduler_process(
|
|
1419
1427
|
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
|
1420
1428
|
|
1421
1429
|
suppress_other_loggers()
|
1430
|
+
parent_process = psutil.Process().parent()
|
1422
1431
|
|
1423
1432
|
try:
|
1424
1433
|
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
@@ -1430,6 +1439,6 @@ def run_scheduler_process(
|
|
1430
1439
|
else:
|
1431
1440
|
scheduler.event_loop_normal()
|
1432
1441
|
except Exception:
|
1433
|
-
|
1434
|
-
logger.error(
|
1435
|
-
|
1442
|
+
traceback = get_exception_traceback()
|
1443
|
+
logger.error(f"Scheduler hit an exception: {traceback}")
|
1444
|
+
parent_process.send_signal(signal.SIGQUIT)
|
@@ -10,10 +10,7 @@
|
|
10
10
|
# limitations under the License.
|
11
11
|
# ==============================================================================
|
12
12
|
|
13
|
-
import copy
|
14
13
|
import uuid
|
15
|
-
from dataclasses import dataclass
|
16
|
-
from typing import Optional
|
17
14
|
|
18
15
|
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
19
16
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
|
@@ -45,8 +45,6 @@ from sglang.srt.managers.io_struct import (
|
|
45
45
|
EmbeddingReqInput,
|
46
46
|
FlushCacheReq,
|
47
47
|
GenerateReqInput,
|
48
|
-
GetMemPoolSizeReq,
|
49
|
-
GetMemPoolSizeReqOutput,
|
50
48
|
OpenSessionReqInput,
|
51
49
|
OpenSessionReqOutput,
|
52
50
|
ProfileReq,
|
@@ -58,7 +56,7 @@ from sglang.srt.managers.io_struct import (
|
|
58
56
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
59
57
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
60
58
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
61
|
-
from sglang.srt.utils import get_zmq_socket,
|
59
|
+
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
62
60
|
|
63
61
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
64
62
|
|
@@ -218,7 +216,8 @@ class TokenizerManager:
|
|
218
216
|
input_ids = obj.input_ids
|
219
217
|
|
220
218
|
if self.is_generation:
|
221
|
-
|
219
|
+
# TODO: also support getting embeddings for multimodal models
|
220
|
+
image_inputs: Dict = await self.image_processor.process_images_async(
|
222
221
|
obj.image_data, input_text or input_ids, obj
|
223
222
|
)
|
224
223
|
if image_inputs and "input_ids" in image_inputs:
|
@@ -406,25 +405,6 @@ class TokenizerManager:
|
|
406
405
|
req = ProfileReq.STOP_PROFILE
|
407
406
|
self.send_to_scheduler.send_pyobj(req)
|
408
407
|
|
409
|
-
async def get_memory_pool_size(self):
|
410
|
-
if self.to_create_loop:
|
411
|
-
self.create_handle_loop()
|
412
|
-
|
413
|
-
req = GetMemPoolSizeReq()
|
414
|
-
|
415
|
-
self.send_to_scheduler.send_pyobj(req)
|
416
|
-
self.mem_pool_size = asyncio.Future()
|
417
|
-
|
418
|
-
# FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
|
419
|
-
if self.server_args.dp_size == 1:
|
420
|
-
res = await self.mem_pool_size
|
421
|
-
return res.size
|
422
|
-
else: # self.server_args.dp_size > 1
|
423
|
-
self.mem_pool_size_tmp = []
|
424
|
-
res = await self.mem_pool_size
|
425
|
-
ret = [r.size for r in res]
|
426
|
-
return ret
|
427
|
-
|
428
408
|
async def update_weights(
|
429
409
|
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
430
410
|
):
|
@@ -532,7 +512,7 @@ class TokenizerManager:
|
|
532
512
|
else:
|
533
513
|
break
|
534
514
|
|
535
|
-
|
515
|
+
kill_process_tree(os.getpid(), include_parent=True)
|
536
516
|
sys.exit(0)
|
537
517
|
|
538
518
|
async def handle_loop(self):
|
@@ -552,15 +532,6 @@ class TokenizerManager:
|
|
552
532
|
if len(self.model_update_tmp) == self.server_args.dp_size:
|
553
533
|
self.model_update_result.set_result(self.model_update_tmp)
|
554
534
|
continue
|
555
|
-
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
556
|
-
if self.server_args.dp_size == 1:
|
557
|
-
self.mem_pool_size.set_result(recv_obj)
|
558
|
-
else: # self.sever_args.dp_size > 1
|
559
|
-
self.mem_pool_size_tmp.append(recv_obj)
|
560
|
-
# set future if the all results are received
|
561
|
-
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
|
562
|
-
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
|
563
|
-
continue
|
564
535
|
elif isinstance(recv_obj, OpenSessionReqOutput):
|
565
536
|
self.session_futures[recv_obj.session_id].set_result(
|
566
537
|
recv_obj.session_id
|