sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +51 -13
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +6 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +7 -6
- sglang/srt/managers/detokenizer_manager.py +9 -11
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +70 -78
- sglang/srt/managers/schedule_batch.py +33 -49
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +137 -80
- sglang/srt/managers/tokenizer_manager.py +224 -336
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/model_runner.py +8 -17
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/server.py +31 -35
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/runners.py +2 -1
- sglang/test/test_utils.py +73 -25
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post2.dist-info/METADATA +0 -899
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py
CHANGED
@@ -71,6 +71,7 @@ from sglang.srt.openai_api.protocol import (
     TopLogprob,
     UsageInfo,
 )
+from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
@@ -314,6 +315,8 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         )
 
     except Exception as e:
+        logger.error(f"error: {get_exception_traceback()}")
+        responses = []
         error_json = {
             "id": f"batch_req_{uuid.uuid4()}",
             "custom_id": request_data.get("custom_id"),
@@ -363,7 +366,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         }
 
     except Exception as e:
-        logger.error("error
+        logger.error(f"error: {e}")
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -469,80 +472,67 @@ async def v1_retrieve_file_content(file_id: str):
 def v1_generate_request(
     all_requests: List[CompletionRequest], request_ids: List[str] = None
 ):
+    if len(all_requests) > 1:
+        first_prompt_type = type(all_requests[0].prompt)
+        for request in all_requests:
+            assert (
+                type(request.prompt) is first_prompt_type
+            ), "All prompts must be of the same type in file input settings"
+            if request.n > 1:
+                raise ValueError(
+                    "Parallel sampling is not supported for completions from files"
+                )
+
     prompts = []
     sampling_params_list = []
     return_logprobs = []
     logprob_start_lens = []
     top_logprobs_nums = []
 
-    # NOTE: with openai API, the prompt's logprobs are always not computed
-    first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
-
-            type(request.prompt) is first_prompt_type
-        ), "All prompts must be of the same type in file input settings"
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )
+        # NOTE: with openai API, the prompt's logprobs are always not computed
         if request.echo and request.logprobs:
            logger.warning(
                 "Echo is not compatible with logprobs. "
-                "To compute logprobs of input prompt, please use
+                "To compute logprobs of input prompt, please use the native /generate API."
             )
 
-    for request in all_requests:
         prompts.append(request.prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
+                "stop": request.stop,
+                "stop_token_ids": request.stop_token_ids,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
+                "regex": request.regex,
+                "json_schema": request.json_schema,
+                "n": request.n,
+                "ignore_eos": request.ignore_eos,
+                "no_stop_trim": request.no_stop_trim,
+            }
+        )
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        sampling_params = []
-        if isinstance(request.no_stop_trim, list):
-            num_reqs = len(request.prompt)
-        else:
-            num_reqs = 1
-        for i in range(num_reqs):
-            sampling_params.append(
-                {
-                    "temperature": request.temperature,
-                    "max_new_tokens": request.max_tokens,
-                    "min_new_tokens": request.min_tokens,
-                    "stop": request.stop,
-                    "stop_token_ids": request.stop_token_ids,
-                    "top_p": request.top_p,
-                    "presence_penalty": request.presence_penalty,
-                    "frequency_penalty": request.frequency_penalty,
-                    "repetition_penalty": request.repetition_penalty,
-                    "regex": request.regex,
-                    "json_schema": request.json_schema,
-                    "n": request.n,
-                    "ignore_eos": request.ignore_eos,
-                    "no_stop_trim": (
-                        request.no_stop_trim
-                        if not isinstance(request.no_stop_trim, list)
-                        else request.no_stop_trim[i]
-                    ),
-                }
-            )
-        if num_reqs == 1:
-            sampling_params_list.append(sampling_params[0])
-        else:
-            sampling_params_list.append(sampling_params)
 
     if len(all_requests) == 1:
-
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
+            prompt_kwargs = {"text": prompts[0]}
+        else:
+            prompt_kwargs = {"input_ids": prompts[0]}
         sampling_params_list = sampling_params_list[0]
-        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
-        if isinstance(prompt, str) or isinstance(prompt[0], str):
-            prompt_kwargs = {"text": prompt}
-        else:
-            prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str):
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
@@ -558,9 +548,7 @@ def v1_generate_request(
         rid=request_ids,
     )
 
-    if len(all_requests)
-        return adapted_request, all_requests[0]
-    return adapted_request, all_requests
+    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
 
 
 def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
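
The reworked `v1_generate_request` above validates batched (file-style) inputs up front: every prompt in the batch must share one type (all strings or all token-id lists) and `n > 1` is rejected for completions from files, while the per-request sampling parameters are now built once per request. A minimal client-side sketch of a batched call that satisfies these checks; the host, port, and model name are illustrative assumptions:

```python
import requests

# Illustrative endpoint; adjust host/port/model to your deployment.
BASE_URL = "http://127.0.0.1:30000"

# All prompts share one type (plain strings here) and n stays at 1,
# which satisfies the checks added in v1_generate_request.
payload = {
    "model": "default",
    "prompt": [
        "The capital of France is",
        "The capital of Japan is",
    ],
    "max_tokens": 16,
    "temperature": 0,
}

resp = requests.post(f"{BASE_URL}/v1/completions", json=payload)
resp.raise_for_status()
for choice in resp.json()["choices"]:
    print(choice["index"], choice["text"])
```
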
@@ -595,7 +583,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
         if isinstance(request, list) and request[idx].echo:
             echo = True
             text = request[idx].prompt + text
-        if
+        if echo and not isinstance(request, list):
             prompt_index = idx // request.n
             text = prompts[prompt_index] + text
 
@@ -709,7 +697,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         async for content in tokenizer_manager.generate_request(
             adapted_request, raw_request
         ):
-            index = content
+            index = content.get("index", 0)
 
             stream_buffer = stream_buffers.get(index, "")
             n_prev_token = n_prev_tokens.get(index, 0)
@@ -945,19 +933,18 @@ def v1_chat_generate_request(
         sampling_params_list.append(sampling_params)
 
         image_data_list.append(image_data)
-        modalities_list.
+        modalities_list.append(modalities)
     if len(all_requests) == 1:
-
-
-            prompt_kwargs = {"text": input_ids}
+        if isinstance(input_ids[0], str):
+            prompt_kwargs = {"text": input_ids[0]}
         else:
-            prompt_kwargs = {"input_ids": input_ids}
+            prompt_kwargs = {"input_ids": input_ids[0]}
         sampling_params_list = sampling_params_list[0]
         image_data_list = image_data_list[0]
         return_logprobs = return_logprobs[0]
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
-        modalities_list = modalities_list[
+        modalities_list = modalities_list[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
@@ -976,9 +963,8 @@ def v1_chat_generate_request(
         rid=request_ids,
         modalities=modalities_list,
     )
-
-
-    return adapted_request, all_requests
+
+    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
 
 
 def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
@@ -1116,7 +1102,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         async for content in tokenizer_manager.generate_request(
             adapted_request, raw_request
         ):
-            index = content
+            index = content.get("index", 0)
 
             is_first = is_firsts.get(index, True)
             stream_buffer = stream_buffers.get(index, "")
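
Both streaming paths now read the choice index defensively (`content.get("index", 0)`) and keep one buffer per index, so parallel-sampling streams can interleave chunks from different choices. A sketch of regrouping such a stream on the client side with the `openai` package; the base URL, API key, and model name are assumptions for a locally launched server:

```python
from collections import defaultdict

from openai import OpenAI

# Assumed local OpenAI-compatible endpoint exposed by the sglang server.
client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Name three prime numbers."}],
    n=2,          # parallel sampling: chunks carry different choice indices
    stream=True,
)

buffers = defaultdict(str)
for chunk in stream:
    for choice in chunk.choices:
        # Mirror the server-side fallback to index 0 when no index is given.
        idx = choice.index if choice.index is not None else 0
        buffers[idx] += choice.delta.content or ""

for idx, text in sorted(buffers.items()):
    print(f"[choice {idx}] {text}")
```
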
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional
 import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
-from sglang.srt.constrained import
+from sglang.srt.constrained.grammar import Grammar
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
     logit_bias: torch.Tensor = None
-    vocab_mask: torch.Tensor = None
+    vocab_mask: Optional[torch.Tensor] = None
 
-
-    regex_fsms: List[RegexGuide] = None
-    regex_fsm_states: List[int] = None
+    grammars: Optional[List[Optional[Grammar]]] = None
 
     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -136,8 +134,7 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
-
-        if not has_regex:
+        if not self.grammars or not any(grammar for grammar in self.grammars):
            self.vocab_mask = None
            return
 
@@ -147,12 +144,9 @@ class SamplingBatchInfo:
            dtype=torch.bool,
            device=self.device,
        )
-        for i,
-            if
-                self.vocab_mask[i].
-                self.vocab_mask[i][
-                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                ] = 0
+        for i, grammar in enumerate(self.grammars):
+            if grammar is not None:
+                grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
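
`SamplingBatchInfo` now carries one optional `Grammar` per request (the backend-agnostic wrapper added in `sglang/srt/constrained/grammar.py`) and delegates mask construction to `grammar.fill_vocab_mask`. The sketch below illustrates, with plain PyTorch and a toy stand-in for the mask-filling call, how such a vocab mask is applied to logits before sampling; it shows the general technique, not sglang's exact code:

```python
import torch

vocab_size, batch_size = 8, 2

# True marks a token as disallowed, mirroring a mask initialized to all-ones
# that the grammar then clears for the tokens it allows.
vocab_mask = torch.ones(batch_size, vocab_size, dtype=torch.bool)


def toy_fill_vocab_mask(row_mask: torch.Tensor, allowed_tokens) -> None:
    """Toy stand-in for Grammar.fill_vocab_mask: clear the mask for allowed tokens."""
    row_mask[list(allowed_tokens)] = False


toy_fill_vocab_mask(vocab_mask[0], {1, 3})  # request 0 may only emit tokens 1 and 3
vocab_mask[1] = False                       # request 1 is unconstrained

logits = torch.randn(batch_size, vocab_size)
logits = logits.masked_fill(vocab_mask, float("-inf"))  # masked tokens can never be sampled

probs = torch.softmax(logits, dim=-1)
print(probs[0])  # non-zero probability only at indices 1 and 3
```
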
sglang/srt/server.py
CHANGED
@@ -53,7 +53,6 @@ from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
-    RewardReqInput,
     UpdateWeightReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
@@ -91,7 +90,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
 app = FastAPI()
-tokenizer_manager = None
+tokenizer_manager: TokenizerManager = None
 
 app.add_middleware(
     CORSMiddleware,
@@ -139,7 +138,7 @@ async def get_server_args():
     return dataclasses.asdict(tokenizer_manager.server_args)
 
 
-@app.
+@app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
     tokenizer_manager.flush_cache()
@@ -177,9 +176,10 @@ async def get_memory_pool_size():
     """Get the memory pool size in number of tokens"""
     try:
         ret = await tokenizer_manager.get_memory_pool_size()
-
+
+        return ret
     except Exception as e:
-        return
+        return ORJSONResponse(
             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
 
@@ -253,8 +253,8 @@ app.post("/encode")(encode_request)
 app.put("/encode")(encode_request)
 
 
-async def judge_request(obj:
-    """Handle a reward model request."""
+async def judge_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
@@ -441,7 +441,7 @@ def launch_server(
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
     )
     t.start()
 
@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer, pid):
+def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(
+        kill_child_process(include_self=True)
         return
 
     model_info = res.json()
@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(
+        kill_child_process(include_self=True)
         return
 
     # logger.info(f"{res.json()=}")
@@ -617,7 +617,7 @@ class Runtime:
 
     def shutdown(self):
         if self.pid is not None:
-            kill_child_process(self.pid)
+            kill_child_process(self.pid, include_self=True)
             self.pid = None
 
     def cache_prefix(self, prefix: str):
@@ -696,24 +696,8 @@ class Runtime:
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
     ):
-
-
-            json_data = {
-                "text": prompt,
-            }
-            response = requests.post(
-                self.url + "/encode",
-                json=json_data,
-            )
-        else:
-            # reward
-            json_data = {
-                "conv": prompt,
-            }
-            response = requests.post(
-                self.url + "/judge",
-                json=json_data,
-            )
+        json_data = {"text": prompt}
+        response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())
 
     def __del__(self):
@@ -736,24 +720,32 @@ class Engine:
 
         # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
         atexit.register(self.shutdown)
+
+        # runtime server default log level is log
+        # offline engine works in scripts, so we set it to error
+
+        if 'log_level' not in kwargs:
+            kwargs['log_level'] = 'error'
 
         server_args = ServerArgs(*args, **kwargs)
         launch_engine(server_args=server_args)
 
     def generate(
         self,
-        prompt
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
         sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
     ):
-        # TODO (ByronHsu): refactor to reduce the duplicated code
-
         obj = GenerateReqInput(
             text=prompt,
+            input_ids=input_ids,
             sampling_params=sampling_params,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
@@ -791,8 +783,11 @@ class Engine:
 
     async def async_generate(
         self,
-        prompt
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
         sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -801,6 +796,7 @@ class Engine:
     ):
         obj = GenerateReqInput(
             text=prompt,
+            input_ids=input_ids,
             sampling_params=sampling_params,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
@@ -834,7 +830,7 @@ class Engine:
         return ret
 
     def shutdown(self):
-        kill_child_process(
+        kill_child_process()
 
     def get_tokenizer(self):
         global tokenizer_manager
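
The offline `Engine` now defaults `log_level` to `"error"` when the caller does not set it, and `generate`/`async_generate` accept pre-tokenized input through the new `input_ids` argument. A usage sketch based on the signatures above; the top-level `sglang.Engine` import and the model path are assumptions (the class itself lives in `sglang/srt/server.py`):

```python
import sglang as sgl  # assumes Engine is exported at the package top level

# Offline engine: log_level falls back to "error" unless overridden.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

# Text prompts...
out = engine.generate(
    prompt=["The capital of France is", "The capital of Japan is"],
    sampling_params={"temperature": 0, "max_new_tokens": 16},
)
print(out)

# ...or pre-tokenized input via the new input_ids argument (ids are illustrative).
out = engine.generate(
    input_ids=[[128000, 791, 6864, 315, 9822, 374]],
    sampling_params={"temperature": 0, "max_new_tokens": 16},
)
print(out)

engine.shutdown()
```
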
sglang/srt/server_args.py
CHANGED
@@ -63,6 +63,7 @@ class ServerArgs:
     stream_interval: int = 1
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
+    decode_log_interval: int = 40
 
     # Logging
     log_level: str = "info"
@@ -74,6 +75,7 @@ class ServerArgs:
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
     enable_cache_report: bool = False
+    watchdog_timeout: float = 600
 
     # Data parallelism
     dp_size: int = 1
@@ -102,6 +104,7 @@ class ServerArgs:
     # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
+    grammar_backend: Optional[str] = "outlines"
 
     # Optimization/debug options
     disable_flashinfer: bool = False
@@ -118,7 +121,8 @@ class ServerArgs:
     enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
-
+    torch_compile_max_bs: int = 32
+    cuda_graph_max_bs: int = 160
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
@@ -427,6 +431,18 @@ class ServerArgs:
             action="store_true",
             help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
         )
+        parser.add_argument(
+            "--watchdog-timeout",
+            type=float,
+            default=ServerArgs.watchdog_timeout,
+            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
+        )
+        parser.add_argument(
+            "--decode-log-interval",
+            type=int,
+            default=ServerArgs.decode_log_interval,
+            help="The log interval of decode batch"
+        )
 
         # Data parallelism
         parser.add_argument(
@@ -537,6 +553,13 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+        parser.add_argument(
+            "--grammar-backend",
+            type=str,
+            choices=["xgrammar", "outlines"],
+            default=ServerArgs.grammar_backend,
+            help="Choose the backend for constrained decoding.",
+        )
 
         # Optimization/debug options
         parser.add_argument(
@@ -611,11 +634,17 @@ class ServerArgs:
             help="Optimize the model with torch.compile. Experimental feature.",
         )
         parser.add_argument(
-            "--
+            "--torch-compile-max-bs",
             type=int,
-            default=ServerArgs.
+            default=ServerArgs.torch_compile_max_bs,
             help="Set the maximum batch size when using torch compile.",
         )
+        parser.add_argument(
+            "--cuda-graph-max-bs",
+            type=int,
+            default=ServerArgs.cuda_graph_max_bs,
+            help="Set the maximum batch size for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
@@ -712,11 +741,11 @@ class PortArgs:
 
     @staticmethod
     def init_new(server_args) -> "PortArgs":
-        port = server_args.port +
+        port = server_args.port + 42
         while True:
             if is_port_available(port):
                 break
-            port +=
+            port += 42
 
         return PortArgs(
             tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
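
The new options (`--grammar-backend`, `--watchdog-timeout`, `--decode-log-interval`, `--torch-compile-max-bs`, `--cuda-graph-max-bs`) map one-to-one onto `ServerArgs` fields, so they can be set on the command line or programmatically. A sketch with a placeholder model path:

```python
from sglang.srt.server_args import ServerArgs

# Roughly equivalent to:
#   python -m sglang.launch_server --model-path <model> \
#       --grammar-backend xgrammar --watchdog-timeout 300 \
#       --decode-log-interval 20 --cuda-graph-max-bs 64
args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    grammar_backend="xgrammar",  # or "outlines" (the default)
    watchdog_timeout=300,        # crash instead of hanging on a stuck forward batch
    decode_log_interval=20,      # log decode-batch stats every 20 steps
    torch_compile_max_bs=32,
    cuda_graph_max_bs=64,
)
print(args.grammar_backend, args.watchdog_timeout)
```
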
sglang/srt/utils.py
CHANGED
@@ -35,6 +35,7 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
+import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from torch import nn
@@ -203,56 +204,6 @@ def is_port_available(port):
     return False
 
 
-def is_multimodal_model(model_architectures):
-    if (
-        "LlavaLlamaForCausalLM" in model_architectures
-        or "LlavaQwenForCausalLM" in model_architectures
-        or "LlavaMistralForCausalLM" in model_architectures
-        or "LlavaVidForCausalLM" in model_architectures
-        or "MllamaForConditionalGeneration" in model_architectures
-        or "Qwen2VLForConditionalGeneration" in model_architectures
-    ):
-        return True
-    else:
-        return False
-
-
-def is_attention_free_model(model_architectures):
-    return False
-
-
-def model_has_inner_state(model_architectures):
-    return False
-
-
-def is_embedding_model(model_architectures):
-    if (
-        "LlamaEmbeddingModel" in model_architectures
-        or "MistralModel" in model_architectures
-        or "LlamaForSequenceClassification" in model_architectures
-        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
-    ):
-        return True
-    else:
-        return False
-
-
-def is_generation_model(model_architectures, is_embedding: bool = False):
-    # We have two ways to determine whether a model is a generative model.
-    # 1. Check the model architectue
-    # 2. check the `is_embedding` server args
-
-    if (
-        "LlamaEmbeddingModel" in model_architectures
-        or "MistralModel" in model_architectures
-        or "LlamaForSequenceClassification" in model_architectures
-        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
-    ):
-        return False
-    else:
-        return not is_embedding
-
-
 def decode_video_base64(video_base64):
     from PIL import Image
 
@@ -397,17 +348,26 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    kill_child_process(
+    kill_child_process(
+        parent_process.pid, include_self=True, skip_pid=current_process.pid
+    )
+    try:
+        current_process.kill()
+    except psutil.NoSuchProcess:
+        pass
 
 
-def kill_child_process(pid, including_parent=True, skip_pid=None):
+def kill_child_process(pid=None, include_self=False, skip_pid=None):
     """Kill the process and all its children process."""
+    if pid is None:
+        pid = os.getpid()
+
     try:
-
+        itself = psutil.Process(pid)
     except psutil.NoSuchProcess:
         return
 
-    children =
+    children = itself.children(recursive=True)
     for child in children:
         if child.pid == skip_pid:
             continue
@@ -416,9 +376,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
         except psutil.NoSuchProcess:
             pass
 
-    if
+    if include_self:
         try:
-
+            itself.kill()
         except psutil.NoSuchProcess:
            pass
 
@@ -720,3 +680,27 @@ def first_rank_print(*args, **kwargs):
         print(*args, **kwargs)
     else:
         pass
+
+
+def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
+    mem = psutil.virtual_memory()
+    total_mem = mem.total / 1024**3
+    available_mem = mem.available / 1024**3
+    if total_mem > 32 and available_mem > 16:
+        buf_size = int(0.5 * 1024**3)
+    else:
+        buf_size = -1
+
+    socket = context.socket(socket_type)
+    if socket_type == zmq.PUSH:
+        socket.setsockopt(zmq.SNDHWM, 0)
+        socket.setsockopt(zmq.SNDBUF, buf_size)
+        socket.connect(f"ipc://{endpoint}")
+    elif socket_type == zmq.PULL:
+        socket.setsockopt(zmq.RCVHWM, 0)
+        socket.setsockopt(zmq.RCVBUF, buf_size)
+        socket.bind(f"ipc://{endpoint}")
+    else:
+        raise ValueError(f"Unsupported socket type: {socket_type}")
+
+    return socket
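
The new `get_zmq_socket` helper centralizes the ZeroMQ setup used by the manager processes: PUSH sockets connect and PULL sockets bind on an IPC endpoint, with the high-water mark disabled and a 0.5 GB socket buffer on hosts with enough free memory. A small self-contained sketch of a PULL/PUSH pair over a temporary IPC path; the endpoint name and payload are arbitrary:

```python
import tempfile

import zmq

from sglang.srt.utils import get_zmq_socket

ctx = zmq.Context(io_threads=1)
endpoint = tempfile.NamedTemporaryFile(delete=False).name  # arbitrary IPC path

receiver = get_zmq_socket(ctx, zmq.PULL, endpoint)  # PULL binds
sender = get_zmq_socket(ctx, zmq.PUSH, endpoint)    # PUSH connects

sender.send_pyobj({"rid": "req-0", "text": "hello"})
print(receiver.recv_pyobj())

sender.close()
receiver.close()
ctx.term()
```
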