sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +100 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/logits_processor.py +56 -19
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +101 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +46 -166
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +118 -24
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +32 -8
- sglang/srt/model_executor/forward_batch_info.py +51 -26
- sglang/srt/model_executor/model_runner.py +201 -58
- sglang/srt/models/gemma2.py +10 -6
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +11 -1
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/qwen2.py +9 -3
- sglang/srt/openai_api/adapter.py +200 -39
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +136 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
- sglang/srt/server.py +92 -57
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +22 -30
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
- sglang-0.2.14.post1.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server.py
CHANGED
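The hunks below add a `/health_generate` probe that decodes a single token, an `/update_weights` endpoint backed by the new `UpdateWeightReqInput`, an OpenAI-compatible `/v1/batches/{batch_id}/cancel` route, and batched prompts plus `logprob_start_len` in `Runtime.generate`/`Runtime.encode`. A minimal client-side sketch of the new HTTP routes follows; the server address, batch id, and the `update_weights` payload field are illustrative assumptions, not taken from this diff.

```python
import requests

BASE = "http://127.0.0.1:30000"  # assumed local server address

# Liveness probe that actually generates one token (new in this release).
print(requests.get(f"{BASE}/health_generate").status_code)  # 200 healthy, 503 otherwise

# Reload model weights in place (payload field name is an assumption).
r = requests.post(f"{BASE}/update_weights", json={"model_path": "/path/to/new/checkpoint"})
print(r.status_code, r.json())  # {"message": ..., "success": "True"/"False"}

# Cancel an OpenAI-compatible batch job (batch id is a placeholder).
print(requests.post(f"{BASE}/v1/batches/batch_abc123/cancel").status_code)
```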
```diff
@@ -24,7 +24,6 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import sys
 import threading
 import time
 from http import HTTPStatus
@@ -34,7 +33,6 @@ from typing import Dict, List, Optional, Union
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import aiohttp
-import psutil
 import requests
 import uvicorn
 import uvloop
@@ -52,11 +50,16 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    UpdateWeightReqInput,
+)
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_batches,
+    v1_cancel_batch,
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
@@ -72,6 +75,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
+    configure_logger,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
@@ -92,10 +96,25 @@ tokenizer_manager = None
 
 @app.get("/health")
 async def health() -> Response:
-    """
+    """Check the health of the http server."""
     return Response(status_code=200)
 
 
+@app.get("/health_generate")
+async def health_generate(request: Request) -> Response:
+    """Check the health of the inference server by generating one token."""
+    gri = GenerateReqInput(
+        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+    )
+    try:
+        async for _ in tokenizer_manager.generate_request(gri, request):
+            break
+        return Response(status_code=200)
+    except Exception as e:
+        logger.exception(e)
+        return Response(status_code=503)
+
+
 @app.get("/get_model_info")
 async def get_model_info():
     result = {
@@ -120,6 +139,23 @@ async def flush_cache():
     )
 
 
+@app.post("/update_weights")
+async def update_weights(obj: UpdateWeightReqInput, request: Request):
+
+    success, message = await tokenizer_manager.update_weights(obj, request)
+    content = {"message": message, "success": str(success)}
+    if success:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
@@ -211,6 +247,12 @@ async def openai_v1_batches(raw_request: Request):
     return await v1_batches(tokenizer_manager, raw_request)
 
 
+@app.post("/v1/batches/{batch_id}/cancel")
+async def cancel_batches(batch_id: str):
+    # https://platform.openai.com/docs/api-reference/batch/cancel
+    return await v1_cancel_batch(tokenizer_manager, batch_id)
+
+
 @app.get("/v1/batches/{batch_id}")
 async def retrieve_batch(batch_id: str):
     return await v1_retrieve_batch(batch_id)
@@ -236,15 +278,12 @@ def launch_server(
     """Launch an HTTP server."""
     global tokenizer_manager
 
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)
 
     server_args.check_server_args()
     _set_envs_and_config(server_args)
 
-    # Allocate ports
+    # Allocate ports for inter-process communications
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
@@ -264,27 +303,29 @@ def launch_server(
     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
 
     # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1:
-        …
-        …
-        …
-        …
-        …
-        …
-            range(
-                server_args.node_rank * tp_size_local,
-                (server_args.node_rank + 1) * tp_size_local,
+    if server_args.nnodes > 1 and server_args.node_rank != 0:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        tp_rank_range = list(
+            range(
+                server_args.node_rank * tp_size_local,
+                (server_args.node_rank + 1) * tp_size_local,
             )
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        )
+        procs = launch_tp_servers(
+            gpu_ids,
+            tp_rank_range,
+            server_args,
+            ports[3],
+            model_overide_args,
+        )
+
+        try:
+            for p in procs:
+                p.join()
+        finally:
+            kill_child_process(os.getpid(), including_parent=False)
+        return
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
@@ -297,11 +338,13 @@ def launch_server(
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
+
     proc_controller = mp.Process(
         target=start_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()
+
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -319,15 +362,11 @@ def launch_server(
     if controller_init_state != "init ok" or detoken_init_state != "init ok":
         proc_controller.kill()
         proc_detoken.kill()
-        …
-        …
-        …
+        raise RuntimeError(
+            "Initialization failed. "
+            f"controller_init_state: {controller_init_state}, "
+            f"detoken_init_state: {detoken_init_state}"
         )
-        print(
-            f"Initialization failed. detoken_init_state: {detoken_init_state}",
-            flush=True,
-        )
-        sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     # Add api key authorization
@@ -336,12 +375,12 @@ def launch_server(
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
    )
     t.start()
 
-    # Listen for requests
     try:
+        # Listen for requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -382,14 +421,14 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.
+            "0.1.6",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -412,8 +451,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if not success:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        …
-        …
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -438,21 +478,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             timeout=600,
         )
         assert res.status_code == 200, f"{res}"
-    except Exception
+    except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        …
-        …
-        …
-    # Print warnings here
-    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
-        logger.warning(
-            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
-            "This combination is an experimental feature and we noticed it can lead to "
-            "wrong generation results. If you want to use chunked prefill, it is recommended "
-            "not using `--disable-radix-cache`."
-        )
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
@@ -490,6 +522,7 @@ class Runtime:
 
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
         proc = mp.Process(
             target=launch_server,
             args=(self.server_args, model_overide_args, pipe_writer),
@@ -566,15 +599,17 @@ class Runtime:
 
     def generate(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
         sampling_params: Optional[Dict] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
     ):
         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,
             "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
             "top_logprobs_num": top_logprobs_num,
         }
         response = requests.post(
@@ -585,7 +620,7 @@ class Runtime:
 
     def encode(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
     ):
         json_data = {
             "text": prompt,
```
sglang/srt/server_args.py
CHANGED
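The hunks below add `kv_cache_dtype`, `is_embedding`, `disable_cuda_graph_padding`, `disable_custom_all_reduce`, and `enable_mixed_chunk` to `ServerArgs`, and rename `--attention-reduce-in-fp32` to `--triton-attention-reduce-in-fp32`. A hedged sketch of passing the new options through the Python `Runtime` wrapper, which forwards keyword arguments to `ServerArgs` (the model path is a placeholder; flag behaviour is described by the help strings in the diff):

```python
from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    kv_cache_dtype="fp8_e5m2",        # new: FP8 E5M2 KV cache (CUDA 11.8+)
    disable_cuda_graph_padding=True,  # new: skip CUDA graphs when padding would be needed
    disable_custom_all_reduce=True,   # new: fall back to NCCL for all-reduce
    enable_mixed_chunk=True,          # new: mix prefill and decode with chunked prefill
)
print(runtime.generate("The capital of France is", sampling_params={"max_new_tokens": 8}))
runtime.shutdown()
```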
```diff
@@ -33,11 +33,13 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
+    is_embedding: bool = False
 
     # Port
     host: str = "127.0.0.1"
@@ -79,12 +81,14 @@ class ServerArgs:
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
+    disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
+    disable_custom_all_reduce: bool = False
+    enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False
-    …
-    efficient_weight_load: bool = False
+    triton_attention_reduce_in_fp32: bool = False
 
     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -193,11 +197,23 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--kv-cache-dtype",
+            type=str,
+            default=ServerArgs.kv_cache_dtype,
+            choices=["auto", "fp8_e5m2"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--context-length",
             type=int,
@@ -391,11 +407,27 @@ class ServerArgs:
             action="store_true",
             help="Disable cuda graph.",
         )
+        parser.add_argument(
+            "--disable-cuda-graph-padding",
+            action="store_true",
+            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=False,
+            help="Disable the custom all-reduce kernel and fall back to NCCL.",
+        )
+        parser.add_argument(
+            "--enable-mixed-chunk",
+            action="store_true",
+            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
@@ -409,13 +441,13 @@ class ServerArgs:
         parser.add_argument(
             "--enable-mla",
             action="store_true",
-            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--triton-attention-reduce-in-fp32",
             action="store_true",
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
-            "This only affects Triton attention kernels",
+            "This only affects Triton attention kernels.",
         )
         parser.add_argument(
             "--efficient-weight-load",
@@ -433,15 +465,6 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
-    def print_mode_args(self):
-        return (
-            f"disable_flashinfer={self.disable_flashinfer}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
-            f"disable_radix_cache={self.disable_radix_cache}, "
-            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
-            f"disable_disk_cache={self.disable_disk_cache}, "
-        )
-
     def check_server_args(self):
         assert (
             self.tp_size % self.nnodes == 0
@@ -449,8 +472,13 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
         if "gemma-2" in self.model_path.lower():
-            logger.info(
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
             self.disable_flashinfer = False
 
 
```
sglang/srt/utils.py
CHANGED
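The hunks below make `is_generation_model` honor the new `is_embedding` server argument, add a `skip_pid` escape hatch to `kill_child_process` (with `kill_parent_process` rewritten on top of it), and introduce a `configure_logger` helper that replaces the inline `logging.basicConfig` call in `server.py`. A small usage sketch of the new helper, using a stand-in namespace in place of a full `ServerArgs` (only the `log_level` attribute is read):

```python
import logging
from types import SimpleNamespace

from sglang.srt.utils import configure_logger

args = SimpleNamespace(log_level="info")  # stand-in for ServerArgs
configure_logger(args, prefix=" TP0")     # prefix lands right after the timestamp

logging.getLogger(__name__).info("model loaded")
# -> prints something like "[14:02:31 TP0] model loaded"
```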
```diff
@@ -224,13 +224,18 @@ def is_multimodal_model(model):
         raise ValueError("unrecognized type")
 
 
-def is_generation_model(model_architectures):
+def is_generation_model(model_architectures, is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architectue
+    # 2. check the `is_embedding` server args
+
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
     ):
         return False
-    …
+    else:
+        return not is_embedding
 
 
 def decode_video_base64(video_base64):
@@ -347,7 +352,7 @@ def suppress_other_loggers():
         logging.WARN
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
-    logging.getLogger("vllm.utils").setLevel(logging.
+    logging.getLogger("vllm.utils").setLevel(logging.ERROR)
 
 
 def assert_pkg_version(pkg: str, min_version: str, message: str):
@@ -369,14 +374,11 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    …
-    for child in children:
-        if child.pid != current_process.pid:
-            os.kill(child.pid, 9)
-    os.kill(parent_process.pid, 9)
+    kill_child_process(parent_process.pid, skip_pid=current_process.pid)
 
 
-def kill_child_process(pid, including_parent=True):
+def kill_child_process(pid, including_parent=True, skip_pid=None):
+    """Kill the process and all its children process."""
     try:
         parent = psutil.Process(pid)
     except psutil.NoSuchProcess:
@@ -384,6 +386,8 @@ def kill_child_process(pid, including_parent=True):
 
     children = parent.children(recursive=True)
     for child in children:
+        if child.pid == skip_pid:
+            continue
         try:
             child.kill()
         except psutil.NoSuchProcess:
@@ -452,10 +456,6 @@ def monkey_patch_vllm_dummy_weight_loader():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
                 quant_method.process_weights_after_loading(module)
-            # FIXME: Remove this after Mixtral is updated
-            # to use quant_method.
-            if hasattr(module, "process_weights_after_loading"):
-                module.process_weights_after_loading()
 
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.
@@ -692,7 +692,7 @@ def monkey_patch_vllm_qvk_linear_loader():
     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
 
 
-def add_api_key_middleware(app, api_key):
+def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
         if request.method == "OPTIONS":
@@ -704,7 +704,7 @@ def add_api_key_middleware(app, api_key):
             return await call_next(request)
 
 
-def prepare_model(model_path):
+def prepare_model(model_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
             from modelscope import snapshot_download
@@ -713,7 +713,7 @@ def prepare_model(model_path):
         return model_path
 
 
-def prepare_tokenizer(tokenizer_path):
+def prepare_tokenizer(tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(tokenizer_path):
             from modelscope import snapshot_download
@@ -722,3 +722,13 @@ def prepare_tokenizer(tokenizer_path):
             tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
         )
         return tokenizer_path
+
+
+def configure_logger(server_args, prefix: str = ""):
+    format = f"[%(asctime)s{prefix}] %(message)s"
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format=format,
+        datefmt="%H:%M:%S",
+        force=True,
+    )
```
sglang/test/runners.py
CHANGED
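The hunks below make `torch_dtype` and `is_generation` explicit constructor arguments of `HFRunner` and `SRTRunner` instead of inferring generation mode from the model path, move the test port to `DEFAULT_PORT_FOR_SRT_TEST_RUNNER`, and reorder `DEFAULT_PROMPTS`. A hedged sketch of the updated call pattern; the model name is a placeholder and the `forward` signatures are assumed to match the diff:

```python
import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder

# Both runners now require the dtype and generation mode up front.
hf = HFRunner(model_path, torch_dtype=torch.float16, is_generation=True)
srt = SRTRunner(model_path, torch_dtype=torch.float16, is_generation=True)

hf_out = hf.forward(DEFAULT_PROMPTS, max_new_tokens=8)
srt_out = srt.forward(DEFAULT_PROMPTS, max_new_tokens=8)
```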
```diff
@@ -14,7 +14,7 @@ limitations under the License.
 """
 
 import json
-import multiprocessing
+import multiprocessing as mp
 import os
 from dataclasses import dataclass
 from typing import List, Union
@@ -24,15 +24,15 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from sglang.srt.server import Runtime
-from sglang.
+from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
 
 DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
+    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
-    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
 ]
 
 dirpath = os.path.dirname(__file__)
@@ -63,44 +63,37 @@ class HFRunner:
     def __init__(
         self,
         model_path,
-        torch_dtype
-        …
+        torch_dtype,
+        is_generation,
     ):
-        self.
-        …
+        self.is_generation = is_generation
+
+        self.in_queue = mp.Queue()
+        self.out_queue = mp.Queue()
 
-        self.model_proc =
+        self.model_proc = mp.Process(
             target=self.start_model_process,
             args=(
                 self.in_queue,
                 self.out_queue,
                 model_path,
                 torch_dtype,
-                is_generation_model,
             ),
         )
         self.model_proc.start()
 
-    def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
-    ):
+    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
             torch_dtype=torch_dtype,
-            trust_remote_code=True,
         )
 
-        self.
-        is_generation_model(model_path)
-        if is_generation_model is None
-        else is_generation_model
-        )
-        if self.is_generation_model:
+        if self.is_generation:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
+                trust_remote_code=False,
                 low_cpu_mem_usage=True,
-                trust_remote_code=True,
             ).cuda()
         else:
             from sentence_transformers import SentenceTransformer
@@ -113,7 +106,7 @@ class HFRunner:
         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if self.
+                if self.is_generation:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:
@@ -176,22 +169,20 @@ class SRTRunner:
     def __init__(
         self,
         model_path,
+        torch_dtype,
+        is_generation,
         tp_size=1,
-        …
-        is_generation_model=None,
-        port=5157,
+        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
     ):
-        self.
-        is_generation_model(model_path)
-        if is_generation_model is None
-        else is_generation_model
-        )
+        self.is_generation = is_generation
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
             mem_fraction_static=0.7,
+            trust_remote_code=False,
+            is_embedding=not self.is_generation,
         )
 
     def forward(
@@ -199,7 +190,7 @@ class SRTRunner:
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
     ):
-        if self.
+        if self.is_generation:
             # the return value contains logprobs from prefill
             output_strs = []
             top_input_logprobs = []
@@ -209,6 +200,7 @@ class SRTRunner:
                     prompt,
                     sampling_params=sampling_params,
                     return_logprob=True,
+                    logprob_start_len=0,
                     top_logprobs_num=NUM_TOP_LOGPROBS,
                 )
                 response = json.loads(response)
```