sglang 0.1.22__py3-none-any.whl → 0.1.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/bench_serving.py +243 -25
- sglang/global_config.py +3 -2
- sglang/lang/interpreter.py +1 -0
- sglang/srt/hf_transformers_utils.py +13 -1
- sglang/srt/layers/logits_processor.py +4 -5
- sglang/srt/layers/radix_attention.py +38 -49
- sglang/srt/managers/controller/cuda_graph_runner.py +58 -16
- sglang/srt/managers/controller/infer_batch.py +51 -22
- sglang/srt/managers/controller/model_runner.py +58 -4
- sglang/srt/managers/controller/schedule_heuristic.py +8 -3
- sglang/srt/managers/controller/tp_worker.py +9 -11
- sglang/srt/memory_pool.py +13 -5
- sglang/srt/models/deepseek.py +430 -0
- sglang/srt/models/gpt_bigcode.py +282 -0
- sglang/srt/models/llama2.py +19 -10
- sglang/srt/server.py +26 -1
- sglang/srt/server_args.py +12 -6
- sglang/srt/utils.py +93 -1
- sglang/version.py +1 -0
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/METADATA +10 -6
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/RECORD +25 -36
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/WHEEL +1 -1
- sglang/backend/__init__.py +0 -0
- sglang/backend/anthropic.py +0 -77
- sglang/backend/base_backend.py +0 -80
- sglang/backend/litellm.py +0 -90
- sglang/backend/openai.py +0 -438
- sglang/backend/runtime_endpoint.py +0 -283
- sglang/backend/vertexai.py +0 -149
- sglang/bench.py +0 -627
- sglang/srt/managers/controller/dp_worker.py +0 -113
- sglang/srt/openai_api/api_adapter.py +0 -432
- sglang/srt/openai_api/openai_api_adapter.py +0 -431
- sglang/srt/openai_api/openai_protocol.py +0 -207
- sglang/srt/openai_api_adapter.py +0 -411
- sglang/srt/openai_protocol.py +0 -207
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/LICENSE +0 -0
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -1,5 +1,3 @@
-__version__ = "0.1.22"
-
 # SGL API Components
 from sglang.api import (
     Runtime,
@@ -32,6 +30,8 @@ from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.lang.backend.vertexai import VertexAI
 
+from .version import __version__
+
 # public APIs management
 __all__ = [
     "global_config",
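The file list above also adds a one-line sglang/version.py, which this import re-exports. Its contents are not included in the diff, but a minimal sketch of what such a module would hold (the exact string is an assumption):

    # sglang/version.py -- assumed contents, not shown in this diff
    __version__ = "0.1.25"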
sglang/bench_serving.py
CHANGED
@@ -5,6 +5,9 @@ Benchmark online serving.
 
 Usage:
 python3 -m sglang.bench_serving --backend sglang --num-prompt 10
+
+python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
+python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
 """
 
 import argparse
@@ -19,6 +22,7 @@ import traceback
 import warnings
 from argparse import ArgumentParser as FlexibleArgumentParser
 from dataclasses import dataclass, field
+from datetime import datetime
 from typing import AsyncGenerator, List, Optional, Tuple, Union
 
 import aiohttp
@@ -53,12 +57,80 @@ class RequestFuncOutput:
     itl: List[float] = field(default_factory=list)  # List of inter-token latencies
     prompt_len: int = 0
     error: str = ""
+    output_len: int = 0
 
 
 def remove_prefix(text: str, prefix: str) -> str:
     return text[len(prefix) :] if text.startswith(prefix) else text
 
 
+# trt llm not support ignore_eos
+# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
+async def async_request_trt_llm(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    assert api_url.endswith("generate_stream")
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "accumulate_tokens": True,
+            "text_input": request_func_input.prompt,
+            "temperature": 0.000001,
+            "top_p": 1.0,
+            "max_tokens": request_func_input.output_len,
+            "stream": True,
+            "min_length": request_func_input.output_len,
+            "end_id": 1048576,
+        }
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(url=api_url, json=payload) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")
+
+                        data = json.loads(chunk)
+                        output.generated_text += data["text_output"]
+                        timestamp = time.perf_counter()
+                        # First token
+                        if ttft == 0.0:
+                            ttft = time.perf_counter() - st
+                            output.ttft = ttft
+
+                        # Decoding phase
+                        else:
+                            output.itl.append(timestamp - most_recent_timestamp)
+
+                        most_recent_timestamp = timestamp
+
+                    output.latency = most_recent_timestamp - st
+                    output.success = True
+                    output.output_len = request_func_input.output_len
+
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 # set ignore_eos True by default
 async def async_request_openai_completions(
     request_func_input: RequestFuncInput,
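For orientation, the parsing above implies that the TensorRT-LLM generate_stream endpoint returns server-sent-event style chunks whose JSON body carries a text_output field. A hypothetical chunk, inferred only from the remove_prefix and json.loads calls in this hunk:

    import json

    # Hypothetical streamed chunk, parsed the same way async_request_trt_llm does above
    chunk_bytes = b'data:{"text_output": " Hello"}'
    data = json.loads(chunk_bytes.decode("utf-8")[len("data:"):])
    assert data["text_output"] == " Hello"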
@@ -76,7 +148,7 @@ async def async_request_openai_completions(
             "temperature": 0.0,
             "best_of": 1,
             "max_tokens": request_func_input.output_len,
-            "stream":
+            "stream": not args.disable_stream,
             "ignore_eos": True,
         }
         headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
@@ -99,8 +171,9 @@ async def async_request_openai_completions(
                            continue
 
                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
-
+                            pass
                        else:
                            data = json.loads(chunk)
 
@@ -123,6 +196,7 @@ async def async_request_openai_completions(
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
+                    output.output_len = request_func_input.output_len
                 else:
                     output.error = response.reason or ""
                     output.success = False
@@ -167,6 +241,7 @@ ASYNC_REQUEST_FUNCS = {
     "sglang": async_request_openai_completions,
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
+    "trt": async_request_trt_llm,
 }
 
 
@@ -175,9 +250,11 @@ class BenchmarkMetrics:
     completed: int
     total_input: int
     total_output: int
+    total_output_retokenized: int
     request_throughput: float
     input_throughput: float
     output_throughput: float
+    output_throughput_retokenized: float
     mean_ttft_ms: float
     median_ttft_ms: float
     std_ttft_ms: float
@@ -190,6 +267,8 @@ class BenchmarkMetrics:
     median_itl_ms: float
     std_itl_ms: float
     p99_itl_ms: float
+    mean_e2e_latency_ms: float
+    median_e2e_latency_ms: float
 
 
 default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
@@ -384,31 +463,36 @@ def calculate_metrics(
     outputs: List[RequestFuncOutput],
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
+    backend: str,
 ) -> Tuple[BenchmarkMetrics, List[int]]:
-
+    output_lens: List[int] = []
+    retokenized_output_lens: List[int] = []
     total_input = 0
     completed = 0
     itls: List[float] = []
     tpots: List[float] = []
     ttfts: List[float] = []
+    e2e_latencies: List[float] = []
     for i in range(len(outputs)):
         if outputs[i].success:
-
-
-
-            # Note : this may inflate the output token count slightly
-            output_len = len(
+            output_len = outputs[i].output_len
+            output_lens.append(output_len)
+            retokenized_output_len = len(
                 tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
             )
-
+            retokenized_output_lens.append(retokenized_output_len)
             total_input += input_requests[i][1]
             if output_len > 1:
                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
             itls += outputs[i].itl
             ttfts.append(outputs[i].ttft)
+
+            e2e_latencies.append(outputs[i].latency)
+
             completed += 1
         else:
-
+            output_lens.append(0)
+            retokenized_output_lens.append(0)
 
     if completed == 0:
         warnings.warn(
@@ -419,10 +503,12 @@ def calculate_metrics(
     metrics = BenchmarkMetrics(
         completed=completed,
         total_input=total_input,
-        total_output=sum(
+        total_output=sum(output_lens),
+        total_output_retokenized=sum(retokenized_output_lens),
         request_throughput=completed / dur_s,
         input_throughput=total_input / dur_s,
-        output_throughput=sum(
+        output_throughput=sum(output_lens) / dur_s,
+        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
         mean_ttft_ms=np.mean(ttfts or 0)
         * 1000,  # ttfts is empty if streaming is not supported by backend
         median_ttft_ms=np.median(ttfts or 0) * 1000,
@@ -436,9 +522,11 @@ def calculate_metrics(
         median_itl_ms=np.median(itls or 0) * 1000,
         std_itl_ms=np.std(itls or 0) * 1000,
         p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
+        mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
+        median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
     )
 
-    return metrics,
+    return metrics, output_lens
 
 
 async def benchmark(
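As a quick check of the formulas used above: TPOT is (latency - ttft) / (output_len - 1), and the new e2e_latencies list simply collects each request's total latency. With hypothetical numbers:

    # Hypothetical request: 0.5 s to first token, 2.0 s total, 101 output tokens
    latency, ttft, output_len = 2.0, 0.5, 101
    tpot_ms = (latency - ttft) / (output_len - 1) * 1000  # 15.0 ms per output token
    e2e_ms = latency * 1000                               # 2000.0 ms end-to-end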
@@ -449,6 +537,7 @@ async def benchmark(
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
     disable_tqdm: bool,
+    enable_multi: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -498,19 +587,26 @@ async def benchmark(
 
     benchmark_duration = time.perf_counter() - benchmark_start_time
 
-    metrics,
+    metrics, output_lens = calculate_metrics(
         input_requests=input_requests,
         outputs=outputs,
         dur_s=benchmark_duration,
         tokenizer=tokenizer,
+        backend=backend,
     )
 
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
+    print("{:<40} {:<10}".format("Backend:", backend))
     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
     print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
+    print(
+        "{:<40} {:<10}".format(
+            "Total generated tokens (retokenized):", metrics.total_output_retokenized
+        )
+    )
     print(
         "{:<40} {:<10.2f}".format(
             "Request throughput (req/s):", metrics.request_throughput
@@ -526,6 +622,15 @@ async def benchmark(
             "Output token throughput (tok/s):", metrics.output_throughput
         )
     )
+    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
+    print(
+        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
+    )
+    print(
+        "{:<40} {:<10.2f}".format(
+            "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
+        )
+    )
     print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
     print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
     print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
@@ -542,11 +647,53 @@ async def benchmark(
     print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
     print("=" * 50)
 
+    if (
+        metrics.median_ttft_ms is not None
+        and metrics.mean_itl_ms is not None
+        and metrics.output_throughput is not None
+    ):
+        result = {
+            "backend": args.backend,
+            "dataset_name": args.dataset_name,
+            "request_rate": request_rate,
+            "total_input": metrics.total_input,
+            "total_output": metrics.total_output,
+            "total_output_retokenized": metrics.total_output_retokenized,
+            "mean_e2e_latency": metrics.mean_e2e_latency_ms,
+            "median_e2e_latency": metrics.median_e2e_latency_ms,
+            "median_ttft": metrics.median_ttft_ms,
+            "median_itl": metrics.median_itl_ms,
+            "output_token_throughput": metrics.output_throughput,
+            "sharegpt_output_len": args.sharegpt_output_len,
+            "random_input_len": args.random_input_len,
+            "random_output_len": args.random_output_len,
+            "random_range_ratio": args.random_range_ratio,
+            "benchmark_duration": benchmark_duration,
+        }
+    else:
+        print(f"Error running benchmark for request rate: {request_rate}")
+        print("-" * 30)
+
+    # Determine output file name
+    if args.output_file:
+        output_file_name = args.output_file
+    else:
+        now = datetime.now().strftime("%m%d")
+        if args.dataset_name == "random":
+            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
+        else:
+            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
+
+    # Append results to a JSONL file
+    with open(output_file_name, "a") as file:
+        file.write(json.dumps(result) + "\n")
+
     result = {
         "duration": benchmark_duration,
         "completed": metrics.completed,
         "total_input_tokens": metrics.total_input,
         "total_output_tokens": metrics.total_output,
+        "total_output_tokens_retokenized": metrics.total_output_retokenized,
         "request_throughput": metrics.request_throughput,
         "input_throughput": metrics.input_throughput,
         "output_throughput": metrics.output_throughput,
@@ -563,15 +710,34 @@ async def benchmark(
         "std_itl_ms": metrics.std_itl_ms,
         "p99_itl_ms": metrics.p99_itl_ms,
         "input_lens": [output.prompt_len for output in outputs],
-        "output_lens":
+        "output_lens": output_lens,
         "ttfts": [output.ttft for output in outputs],
         "itls": [output.itl for output in outputs],
         "generated_texts": [output.generated_text for output in outputs],
         "errors": [output.error for output in outputs],
+        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
+        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
     }
     return result
 
 
+def parse_request_rate_range(request_rate_range):
+    if len(request_rate_range.split(",")) == 3:
+        start, stop, step = map(int, request_rate_range.split(","))
+        return list(range(start, stop, step))
+    else:
+        return list(map(int, request_rate_range.split(",")))
+
+
+def check_chat_template(model_path):
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        return "chat_template" in tokenizer.init_kwargs
+    except Exception as e:
+        print(f"Fail to load tokenizer config with error={e}")
+        return False
+
+
 def fire(args: argparse.Namespace):
     random.seed(args.seed)
     np.random.seed(args.seed)
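A short illustration of how the helper added above expands the --request-rate-range string (the two inputs are the module docstring's example and the argument default):

    parse_request_rate_range("2,34,2")         # three values -> range(start, stop, step): [2, 4, 6, ..., 32]
    parse_request_rate_range("1,2,4,8,16,32")  # any other count -> the literal list [1, 2, 4, 8, 16, 32]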
@@ -581,6 +747,7 @@ def fire(args: argparse.Namespace):
         "sglang": 30000,
         "lmdeploy": 23333,
         "vllm": 8000,
+        "trt": 8000,
     }.get(args.backend, 30000)
 
     api_url = (
@@ -594,6 +761,16 @@ def fire(args: argparse.Namespace):
         else f"http://{args.host}:{args.port}/v1/models"
     )
 
+    if args.backend == "trt":
+        api_url = (
+            f"{args.base_url}/v2/models/ensemble/generate_stream"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
+        )
+        if args.model is None:
+            print("Please provide a model using `--model` when using `trt` backend.")
+            sys.exit(1)
+
     if args.model is None:
         try:
             response = requests.get(model_url)
@@ -610,6 +787,12 @@ def fire(args: argparse.Namespace):
         print("No model specified or found. Please provide a model using `--model`.")
         sys.exit(1)
 
+    if not check_chat_template(args.model):
+        print(
+            "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n"
+            "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n"
+        )
+
     print(f"{args}\n")
 
     backend = args.backend
@@ -637,17 +820,35 @@ def fire(args: argparse.Namespace):
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
 
-
-
-
-
-
-
-
-
-
+    if args.multi:
+        request_rates = parse_request_rate_range(args.request_rate_range)
+
+        for rate in request_rates:
+            asyncio.run(
+                benchmark(
+                    backend=backend,
+                    api_url=api_url,
+                    model_id=model_id,
+                    tokenizer=tokenizer,
+                    input_requests=input_requests,
+                    request_rate=rate,
+                    disable_tqdm=args.disable_tqdm,
+                    enable_multi=args.multi,
+                )
+            )
+    else:
+        asyncio.run(
+            benchmark(
+                backend=backend,
+                api_url=api_url,
+                model_id=model_id,
+                tokenizer=tokenizer,
+                input_requests=input_requests,
+                request_rate=args.request_rate,
+                disable_tqdm=args.disable_tqdm,
+                enable_multi=args.multi,
+            )
         )
-    )
 
 
 # to avoid relying on SGLang's components
@@ -751,6 +952,23 @@ if __name__ == "__main__":
|
|
751
952
|
action="store_true",
|
752
953
|
help="Specify to disable tqdm progress bar.",
|
753
954
|
)
|
955
|
+
parser.add_argument(
|
956
|
+
"--multi",
|
957
|
+
action="store_true",
|
958
|
+
help="Use request rate range rather than single value.",
|
959
|
+
)
|
960
|
+
parser.add_argument(
|
961
|
+
"--request-rate-range",
|
962
|
+
type=str,
|
963
|
+
default="2,34,2",
|
964
|
+
help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
|
965
|
+
)
|
966
|
+
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
|
967
|
+
parser.add_argument(
|
968
|
+
"--disable-stream",
|
969
|
+
action="store_true",
|
970
|
+
help="Disable streaming mode.",
|
971
|
+
)
|
754
972
|
|
755
973
|
set_ulimit()
|
756
974
|
|
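Putting the new pieces together: with --multi, fire() sweeps every parsed request rate and each benchmark() run appends one summary line to the JSONL file, named from backend, date, prompt count, and input/output lengths unless --output-file is given (e.g. sglang_0715_3000_1024_1024.jsonl for a hypothetical July 15 run). A sample record built from the keys of the result dict above, wrapped here for readability and with illustrative values only:

    {"backend": "sglang", "dataset_name": "random", "request_rate": 8, "total_input": 3072000,
     "total_output": 3072000, "total_output_retokenized": 3069112, "mean_e2e_latency": 12450.3,
     "median_e2e_latency": 12391.8, "median_ttft": 312.4, "median_itl": 11.8,
     "output_token_throughput": 8092.7, "sharegpt_output_len": null, "random_input_len": 1024,
     "random_output_len": 1024, "random_range_ratio": 0.5, "benchmark_duration": 379.6}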
sglang/global_config.py
CHANGED
@@ -16,9 +16,9 @@ class GlobalConfig:
         self.wait_for_new_request_delay = 0.0006
 
         # Runtime constants: New generation token ratio estimation
-        self.
+        self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.2
-        self.new_token_ratio_decay = 0.
+        self.new_token_ratio_decay = 0.001
         self.new_token_ratio_recovery = 0.05
 
         # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
@@ -27,6 +27,7 @@ class GlobalConfig:
 
         # Runtime constants: others
         self.num_continue_decode_steps = 10
+        self.retract_decode_steps = 20
         self.flashinfer_workspace_size = 192 * 1024 * 1024
 
         # Output tokenization configs
sglang/lang/interpreter.py
CHANGED
@@ -288,6 +288,7 @@ class StreamExecutor:
             exes[i].text_ = str(self.text_)
             exes[i].messages_ = list(self.messages_)
             exes[i].cur_role = self.cur_role
+            exes[i].cur_role_begin_pos = self.cur_role_begin_pos
             exes[i].fork_start_text_pos = len(self.text_)
             exes[i].images_ = list(self.images_)
 
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -4,19 +4,26 @@ import functools
 import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Literal, Optional, Union
+from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
     AutoConfig,
     AutoProcessor,
     AutoTokenizer,
+    PretrainedConfig,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
+from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
 
 from sglang.srt.utils import is_multimodal_model
 
+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+    ChatGLMConfig.model_type: ChatGLMConfig,
+    DbrxConfig.model_type: DbrxConfig,
+}
+
 
 def download_from_hf(model_path: str):
     if os.path.exists(model_path):
@@ -40,6 +47,9 @@ def get_config(
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if config.model_type in _CONFIG_REGISTRY:
+        config_class = _CONFIG_REGISTRY[config.model_type]
+        config = config_class.from_pretrained(model, revision=revision)
     if model_overide_args:
         config.update(model_overide_args)
     return config
@@ -63,6 +73,8 @@ def get_context_length(config):
     rope_scaling = getattr(config, "rope_scaling", None)
     if rope_scaling:
         rope_scaling_factor = config.rope_scaling["factor"]
+        if config.rope_scaling["rope_type"] == "llama3":
+            rope_scaling_factor = 1
     else:
         rope_scaling_factor = 1
 
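A rough sketch of what the new registry does at load time: for model types that vLLM ships a config class for (chatglm, dbrx), the AutoConfig result is re-loaded through that class so downstream code sees a known schema rather than whatever trust_remote_code returned. The model name below is an assumption for illustration:

    # Hypothetical walk-through of the registry dispatch added in get_config above
    config = AutoConfig.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
    if config.model_type in _CONFIG_REGISTRY:  # "chatglm" is registered
        config = _CONFIG_REGISTRY[config.model_type].from_pretrained("THUDM/chatglm3-6b")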
sglang/srt/layers/logits_processor.py
CHANGED
@@ -34,12 +34,11 @@ class LogitProcessorOutput:
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-
-    # For logprobs
     return_logprob: bool
-
+
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    top_logprobs_nums: List[int] = None
 
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):