sglang 0.3.4.post2__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +205 -3
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +54 -13
- sglang/srt/constrained/__init__.py +2 -48
- sglang/srt/constrained/base_grammar_backend.py +72 -0
- sglang/srt/constrained/outlines_backend.py +165 -0
- sglang/srt/constrained/outlines_jump_forward.py +182 -0
- sglang/srt/constrained/xgrammar_backend.py +114 -0
- sglang/srt/hf_transformers_utils.py +6 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +117 -30
- sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +27 -10
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/quantization/base_config.py +14 -1
- sglang/srt/layers/vocab_parallel_embedding.py +552 -0
- sglang/srt/managers/data_parallel_controller.py +7 -6
- sglang/srt/managers/detokenizer_manager.py +9 -11
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +74 -80
- sglang/srt/managers/schedule_batch.py +35 -57
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +266 -150
- sglang/srt/managers/tokenizer_manager.py +292 -340
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/metrics/collector.py +211 -0
- sglang/srt/metrics/func_timer.py +108 -0
- sglang/srt/mm_utils.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +7 -3
- sglang/srt/model_executor/model_runner.py +10 -18
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gemma2_reward.py +69 -0
- sglang/srt/models/gpt2.py +281 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/internlm2_reward.py +62 -0
- sglang/srt/models/llama.py +25 -12
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +10 -26
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +9 -15
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +58 -68
- sglang/srt/sampling/sampling_batch_info.py +6 -13
- sglang/srt/sampling/sampling_params.py +0 -14
- sglang/srt/server.py +84 -46
- sglang/srt/server_args.py +61 -12
- sglang/srt/utils.py +127 -56
- sglang/test/runners.py +2 -1
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_mgsm.py +2 -2
- sglang/test/test_utils.py +89 -27
- sglang/utils.py +63 -1
- sglang/version.py +1 -1
- sglang-0.3.5.post1.dist-info/METADATA +348 -0
- sglang-0.3.5.post1.dist-info/RECORD +155 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
- sglang/srt/constrained/base_tool_cache.py +0 -65
- sglang/srt/constrained/fsm_cache.py +0 -95
- sglang/srt/constrained/jump_forward.py +0 -203
- sglang-0.3.4.post2.dist-info/METADATA +0 -899
- sglang-0.3.4.post2.dist-info/RECORD +0 -148
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0
sglang/api.py
CHANGED
@@ -99,7 +99,7 @@ def gen(
     regex: Optional[str] = None,
     json_schema: Optional[str] = None,
 ):
-    """Call the model to generate. See the meaning of the arguments in docs/
+    """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""

     if choices:
         return SglSelect(
sglang/bench_latency.py
CHANGED
@@ -129,9 +129,9 @@ def load_model(server_args, port_args, tp_rank):

     model_config = ModelConfig(
         server_args.model_path,
-        server_args.trust_remote_code,
+        trust_remote_code=server_args.trust_remote_code,
         context_length=server_args.context_length,
-        model_override_args=
+        model_override_args=server_args.json_model_override_args,
     )
     model_runner = ModelRunner(
         model_config=model_config,
@@ -550,4 +550,4 @@ if __name__ == "__main__":
     except Exception as e:
         raise e
     finally:
-        kill_child_process(
+        kill_child_process()
sglang/bench_server_latency.py
CHANGED
@@ -15,7 +15,6 @@ import dataclasses
 import itertools
 import json
 import multiprocessing
-import os
 import time
 from typing import Tuple

@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
     except Exception as e:
         raise e
     finally:
-        kill_child_process(
+        kill_child_process()


 def launch_server_process(server_args: ServerArgs):
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         )
     finally:
         if proc:
-            kill_child_process(proc.pid)
+            kill_child_process(proc.pid, include_self=True)

     print(f"\nResults are saved to {bench_args.result_filename}")

sglang/bench_serving.py
CHANGED
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output


+async def async_request_truss(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "model": request_func_input.model,
+            "prompt": prompt,
+            "temperature": 0.0,
+            "best_of": 1,
+            "max_tokens": request_func_input.output_len,
+            "stream": not args.disable_stream,
+            "ignore_eos": not args.disable_ignore_eos,
+            **request_func_input.extra_request_body,
+        }
+        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if data["choices"][0]["delta"]["content"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text += data["choices"][0]["delta"]["content"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_sglang_generate(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -350,6 +429,7 @@ ASYNC_REQUEST_FUNCS = {
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
     "gserver": async_request_gserver,
+    "truss": async_request_truss,
 }


@@ -516,12 +596,20 @@ def sample_random_requests(

     # Filter out sequences that are too long or too short
     input_requests: List[Tuple[str, int, int]] = []
-    for
+    for data in dataset:
+        i = len(input_requests)
+        if i == num_prompts:
+            break
+
         # Tokenize the prompts and completions.
-        prompt =
+        prompt = data[0]
         prompt_token_ids = tokenizer.encode(prompt)
         prompt_len = len(prompt_token_ids)

+        # Skip empty prompt
+        if prompt_len == 0:
+            continue
+
         if prompt_len > input_lens[i]:
             input_ids = prompt_token_ids[: input_lens[i]]
         else:
@@ -547,6 +635,66 @@ def sample_random_requests(
     return input_requests


+def gen_prompt(tokenizer, token_num):
+    """Generate a random prompt of specified token length using tokenizer vocabulary."""
+    all_available_tokens = list(tokenizer.get_vocab().values())
+    selected_tokens = random.choices(all_available_tokens, k=token_num)
+    return tokenizer.decode(selected_tokens)
+
+
+def sample_generated_shared_prefix_requests(
+    num_groups: int,
+    prompts_per_group: int,
+    system_prompt_len: int,
+    question_len: int,
+    output_len: int,
+    tokenizer: PreTrainedTokenizerBase,
+) -> List[Tuple[str, int, int]]:
+    """Generate benchmark requests with shared system prompts using random tokens."""
+    # Generate system prompts for each group
+    system_prompts = []
+    for _ in range(num_groups):
+        system_prompt = gen_prompt(tokenizer, system_prompt_len)
+        system_prompts.append(system_prompt)
+
+    # Generate questions
+    questions = []
+    for _ in range(num_groups * prompts_per_group):
+        question = gen_prompt(tokenizer, question_len)
+        questions.append(question)
+
+    # Combine system prompts with questions
+    input_requests = []
+    total_input_tokens = 0
+    total_output_tokens = 0
+
+    for group_idx in range(num_groups):
+        system_prompt = system_prompts[group_idx]
+        for prompt_idx in range(prompts_per_group):
+            question = questions[group_idx * prompts_per_group + prompt_idx]
+            full_prompt = f"{system_prompt}\n\n{question}"
+            prompt_len = len(tokenizer.encode(full_prompt))
+
+            input_requests.append((full_prompt, prompt_len, output_len))
+            total_input_tokens += prompt_len
+            total_output_tokens += output_len
+
+    print(f"\nGenerated shared prefix dataset statistics:")
+    print(f"Number of groups: {num_groups}")
+    print(f"Prompts per group: {prompts_per_group}")
+    print(f"Total prompts: {len(input_requests)}")
+    print(f"Total input tokens: {total_input_tokens}")
+    print(f"Total output tokens: {total_output_tokens}")
+    print(
+        f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
+    )
+    print(
+        f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
+    )
+
+    return input_requests
+
+
 async def get_request(
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
@@ -873,6 +1021,7 @@ def run_benchmark(args_: argparse.Namespace):
         "vllm": 8000,
         "trt": 8000,
         "gserver": 9988,
+        "truss": 8080,
     }.get(args.backend, 30000)

     model_url = (
@@ -905,9 +1054,20 @@ def run_benchmark(args_: argparse.Namespace):
     elif args.backend == "gserver":
         api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
         args.model = args.model or "default"
+    elif args.backend == "truss":
+        api_url = (
+            f"{args.base_url}/v1/models/model:predict"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/models/model:predict"
+        )

     # Get model name
     if args.model is None:
+        if args.backend == "truss":
+            print(
+                "Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
+            )
+            sys.exit(1)
         try:
             response = requests.get(model_url)
             model_list = response.json().get("data", [])
@@ -956,6 +1116,15 @@ def run_benchmark(args_: argparse.Namespace):
             tokenizer=tokenizer,
             dataset_path=args.dataset_path,
         )
+    elif args.dataset_name == "generated-shared-prefix":
+        input_requests = sample_generated_shared_prefix_requests(
+            num_groups=args.gen_num_groups,
+            prompts_per_group=args.gen_prompts_per_group,
+            system_prompt_len=args.gen_system_prompt_len,
+            question_len=args.gen_question_len,
+            output_len=args.gen_output_len,
+            tokenizer=tokenizer,
+        )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")

@@ -1029,7 +1198,7 @@ if __name__ == "__main__":
         "--dataset-name",
         type=str,
         default="sharegpt",
-        choices=["sharegpt", "random"],
+        choices=["sharegpt", "random", "generated-shared-prefix"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument(
@@ -1116,5 +1285,38 @@ if __name__ == "__main__":
        help="Append given JSON object to the request payload. You can use this to specify"
        "additional generate params like sampling params.",
    )
+
+    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
+    group.add_argument(
+        "--gen-num-groups",
+        type=int,
+        default=64,
+        help="Number of system prompt groups for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-prompts-per-group",
+        type=int,
+        default=16,
+        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-system-prompt-len",
+        type=int,
+        default=2048,
+        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-question-len",
+        type=int,
+        default=128,
+        help="Target length in tokens for questions in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-output-len",
+        type=int,
+        default=256,
+        help="Target length in tokens for outputs in generated-shared-prefix dataset",
+    )
+
    args = parser.parse_args()
    run_benchmark(args)
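Note (not part of the diff): the new generated-shared-prefix dataset stresses shared-prefix handling by giving every group of prompts a common system prompt built from random vocabulary tokens. It can be driven from the CLI via --dataset-name generated-shared-prefix plus the --gen-* flags added above, or programmatically. A minimal sketch, assuming a Hugging Face tokenizer is available locally (the model name below is only an example):

    # Hypothetical usage sketch for the sampler added in this version.
    from transformers import AutoTokenizer

    from sglang.bench_serving import sample_generated_shared_prefix_requests

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

    # Smaller numbers than the CLI defaults (64 groups x 16 prompts), just for a quick check.
    requests = sample_generated_shared_prefix_requests(
        num_groups=4,
        prompts_per_group=2,
        system_prompt_len=256,
        question_len=64,
        output_len=32,
        tokenizer=tokenizer,
    )
    # Each entry is a (full_prompt, prompt_len, output_len) tuple.
    print(len(requests), requests[0][1], requests[0][2])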
sglang/global_config.py
CHANGED
@@ -14,9 +14,15 @@ class GlobalConfig:
         self.default_backend = None

         # Runtime constants: New generation token ratio estimation
-        self.
-
-
+        self.default_init_new_token_ratio = float(
+            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
+        )
+        self.default_min_new_token_ratio_factor = float(
+            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
+        )
+        self.default_new_token_ratio_decay_steps = float(
+            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
+        )

         # Runtime constants: others
         self.retract_decode_steps = 20
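Note (not part of the diff): the three scheduler ratios are now read from environment variables when GlobalConfig is instantiated, so they can be tuned without editing the source. A minimal sketch, assuming the module-level global_config instance that sglang exposes; the variables must be set before sglang is imported:

    import os

    # Override the defaults (0.7 / 0.14 / 600) shown in the diff above.
    os.environ["SGLANG_INIT_NEW_TOKEN_RATIO"] = "0.5"
    os.environ["SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR"] = "0.2"
    os.environ["SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS"] = "300"

    from sglang.global_config import global_config

    print(global_config.default_init_new_token_ratio)         # 0.5
    print(global_config.default_min_new_token_ratio_factor)   # 0.2
    print(global_config.default_new_token_ratio_decay_steps)  # 300.0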
sglang/lang/chat_template.py
CHANGED
@@ -116,12 +116,10 @@ register_chat_template(
     )
 )

-
-# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
-# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+
 register_chat_template(
     ChatTemplate(
-        name="
+        name="chatml-llava",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -130,13 +128,17 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
+        image_token="<image>\n",
     )
 )

-
+
+# There is default system prompt for qwen
+# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
+# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 register_chat_template(
     ChatTemplate(
-        name="
+        name="qwen",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -144,15 +146,14 @@ register_chat_template(
             "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>"),
-        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+        stop_str=("<|im_end|>",),
     )
 )

-
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
 register_chat_template(
     ChatTemplate(
-        name="
+        name="qwen2-vl",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -161,7 +162,7 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
-        image_token="
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
     )
 )

@@ -182,37 +183,46 @@ register_chat_template(
     )
 )

-# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
 register_chat_template(
     ChatTemplate(
-        name="
+        name="llama-2-chat",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": ("", ""),
-            "user": ("
-            "assistant": ("", "
+            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
+            "user": ("[INST] ", " [/INST]"),
+            "assistant": ("", " </s><s>"),
         },
-        style=ChatTemplateStyle.
-        stop_str=("<|im_end|>",),
+        style=ChatTemplateStyle.LLAMA2,
     )
 )

 register_chat_template(
     ChatTemplate(
-        name="llama-
+        name="llama-3-instruct",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": (
-
-
+            "system": (
+                "<|start_header_id|>system<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "user": (
+                "<|start_header_id|>user<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "assistant": (
+                "<|start_header_id|>assistant<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
         },
-
+        stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )

+# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
 register_chat_template(
     ChatTemplate(
-        name="llama-3-instruct",
+        name="llama-3-instruct-llava",
         default_system_prompt=None,
         role_prefix_and_suffix={
             "system": (
@@ -229,7 +239,22 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
-        image_token="
+        image_token="<image>\n",
+    )
+)
+
+# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
+register_chat_template(
+    ChatTemplate(
+        name="yi-1.5",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", ""),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
+            "assistant": ("", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
     )
 )

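Note (not part of the diff): each ChatTemplate registered above is just a table of per-role (prefix, suffix) strings plus stop strings and an optional image token, so a prompt can be assembled by concatenation. A sketch using the "llama-3-instruct" template; get_chat_template() lives in the same module but is not shown in this diff:

    from sglang.lang.chat_template import get_chat_template

    template = get_chat_template("llama-3-instruct")

    messages = [
        ("system", "You are a concise assistant."),
        ("user", "Name one prime number."),
    ]

    prompt = ""
    for role, content in messages:
        prefix, suffix = template.role_prefix_and_suffix[role]
        prompt += prefix + content + suffix
    # Open the assistant turn; generation should stop at one of template.stop_str.
    prompt += template.role_prefix_and_suffix["assistant"][0]

    print(prompt)
    print("stop:", template.stop_str)  # ("<|eot_id|>",)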
sglang/lang/interpreter.py
CHANGED
@@ -54,7 +54,14 @@ def run_internal(state, program, func_args, func_kwargs, sync):


 def run_program(
-    program,
+    program,
+    backend,
+    func_args,
+    func_kwargs,
+    default_sampling_para,
+    stream,
+    sync=False,
+    use_thread=True,
 ):
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
@@ -67,6 +74,7 @@ def run_program(
         chat_template=None,
         stream=stream,
         num_api_spec_tokens=program.num_api_spec_tokens,
+        use_thread=use_thread,
     )
     state = ProgramState(stream_executor)

sglang/lang/ir.py
CHANGED
@@ -168,6 +168,7 @@ class SglFunction:
         return_text_in_logprobs: Optional[bool] = None,
         stream: bool = False,
         backend=None,
+        use_thread: bool = True,
         **kwargs,
     ):
         from sglang.lang.interpreter import run_program
@@ -195,7 +196,15 @@ class SglFunction:
             return_text_in_logprobs=return_text_in_logprobs,
         )
         backend = backend or global_config.default_backend
-        return run_program(
+        return run_program(
+            self,
+            backend,
+            args,
+            kwargs,
+            default_sampling_para,
+            stream,
+            use_thread=use_thread,
+        )

     def run_batch(
         self,
@@ -445,7 +454,7 @@ class SglGen(SglExpr):
         regex: Optional[str] = None,
         json_schema: Optional[str] = None,
     ):
-        """Call the model to generate. See the meaning of the arguments in docs/
+        """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
         super().__init__()
         self.name = name
         self.sampling_params = SglSamplingParams(
sglang/srt/configs/model_config.py
CHANGED
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import json
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Optional
+from typing import List, Optional

 from transformers import PretrainedConfig

@@ -38,18 +39,26 @@ class ModelConfig:
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
+        is_embedding: Optional[bool] = None,
     ) -> None:
-
-        self.
-        self.revision = revision
-        self.model_override_args = model_override_args
+        # Parse args
+        self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-
-            trust_remote_code,
-            revision,
-            model_override_args=model_override_args,
+            path,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            model_override_args=self.model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
+
+        # Check model type
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, is_embedding
+        )
+        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
+        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+
+        # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
         allow_long_context = os.environ.get(
             "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
@@ -81,7 +90,7 @@ class ModelConfig:
             self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
         )

-        # FIXME: temporary special judge for
+        # FIXME: temporary special judge for MLA architecture
        if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
@@ -112,8 +121,6 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

-        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
-
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -163,7 +170,6 @@ class ModelConfig:
         # equal to the number of attention heads.
         return self.hf_text_config.num_attention_heads

-    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
     def get_num_kv_heads(self, tensor_parallel_size) -> int:
         """Returns the number of KV heads per GPU."""
         total_num_kv_heads = self.get_total_num_kv_heads()
@@ -192,3 +198,38 @@ def get_hf_text_config(config: PretrainedConfig):
         return config.text_config
     else:
         return config
+
+
+def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architectue
+    # 2. check the `is_embedding` server args
+
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
+        or "InternLM2ForRewardModel" in model_architectures
+    ):
+        return False
+    else:
+        return not is_embedding
+
+
+def is_multimodal_model(model_architectures: List[str]):
+    if (
+        "LlavaLlamaForCausalLM" in model_architectures
+        or "LlavaQwenForCausalLM" in model_architectures
+        or "LlavaMistralForCausalLM" in model_architectures
+        or "LlavaVidForCausalLM" in model_architectures
+        or "MllamaForConditionalGeneration" in model_architectures
+        or "Qwen2VLForConditionalGeneration" in model_architectures
+    ):
+        return True
+    else:
+        return False
+
+
+def is_encoder_decoder_model(model_architectures: List[str]):
+    return "MllamaForConditionalGeneration" in model_architectures