sglang 0.3.1.post1__py3-none-any.whl → 0.3.1.post3__py3-none-any.whl
This diff shows the changes between these two published package versions as they appear in their public registry and is provided for informational purposes only.
- sglang/bench_latency.py +11 -2
- sglang/bench_server_latency.py +187 -0
- sglang/bench_serving.py +1 -1
- sglang/srt/layers/activation.py +8 -4
- sglang/srt/layers/attention_backend.py +3 -1
- sglang/srt/layers/layernorm.py +10 -7
- sglang/srt/layers/linear.py +1133 -0
- sglang/srt/layers/quantization/__init__.py +76 -0
- sglang/srt/layers/quantization/base_config.py +122 -0
- sglang/srt/layers/sampler.py +9 -2
- sglang/srt/managers/io_struct.py +3 -0
- sglang/srt/managers/policy_scheduler.py +49 -93
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/tp_worker.py +11 -6
- sglang/srt/model_executor/cuda_graph_runner.py +15 -14
- sglang/srt/model_executor/model_runner.py +13 -5
- sglang/srt/models/baichuan.py +1 -1
- sglang/srt/models/chatglm.py +6 -6
- sglang/srt/models/commandr.py +7 -7
- sglang/srt/models/dbrx.py +7 -7
- sglang/srt/models/deepseek.py +7 -7
- sglang/srt/models/deepseek_v2.py +9 -9
- sglang/srt/models/exaone.py +6 -6
- sglang/srt/models/gemma.py +6 -6
- sglang/srt/models/gemma2.py +6 -6
- sglang/srt/models/gpt_bigcode.py +6 -6
- sglang/srt/models/grok.py +6 -6
- sglang/srt/models/internlm2.py +6 -6
- sglang/srt/models/llama.py +7 -9
- sglang/srt/models/llama_classification.py +3 -4
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +6 -6
- sglang/srt/models/minicpm3.py +3 -3
- sglang/srt/models/mixtral.py +6 -6
- sglang/srt/models/mixtral_quant.py +6 -6
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen.py +6 -6
- sglang/srt/models/qwen2.py +6 -6
- sglang/srt/models/qwen2_moe.py +7 -7
- sglang/srt/models/stablelm.py +6 -6
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +2 -5
- sglang/srt/models/yivl.py +1 -1
- sglang/srt/server_args.py +17 -21
- sglang/srt/utils.py +21 -1
- sglang/test/few_shot_gsm8k.py +8 -2
- sglang/test/test_utils.py +5 -2
- sglang/version.py +1 -1
- {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/METADATA +5 -5
- {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/RECORD +54 -50
- {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post1.dist-info → sglang-0.3.1.post3.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
@@ -1,5 +1,7 @@
 """
-Benchmark the latency of a
+Benchmark the latency of running a single static batch.
+This script does not launch a server and uses the low-level APIs.
+It accepts arguments similar to those of launch_server.py.
 
 # Usage (latency test)
 ## with dummy weights:
@@ -62,8 +64,13 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.server import _set_envs_and_config
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    configure_logger,
+    kill_child_process,
+    suppress_other_loggers,
+)
 
 
 @dataclasses.dataclass
@@ -339,6 +346,8 @@ def latency_test(
     bench_args,
     tp_rank,
 ):
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+    _set_envs_and_config(server_args)
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
sglang/bench_server_latency.py
ADDED
@@ -0,0 +1,187 @@
+"""
+Benchmark the latency of serving a single batch with a real server.
+This script launches a server and uses the HTTP interface.
+It accepts arguments similar to those of launch_server.py.
+
+Usage:
+
+python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+"""
+
+import argparse
+import dataclasses
+import itertools
+import json
+import multiprocessing
+import os
+import time
+from typing import Tuple
+
+import numpy as np
+import requests
+
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_child_process
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    run_name: str = "default"
+    batch_size: Tuple[int] = (1,)
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (16,)
+    result_filename: str = "result.jsonl"
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # use the default value's type to case the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+def launch_server_internal(server_args):
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
+
+
+def launch_server_process(server_args: ServerArgs):
+    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
+    proc.start()
+    base_url = f"http://{server_args.host}:{server_args.port}"
+    timeout = 600
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+            }
+            response = requests.get(f"{base_url}/v1/models", headers=headers)
+            if response.status_code == 200:
+                return proc, base_url
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError("Server failed to start within the timeout period.")
+
+
+def run_one_case(
+    url: str,
+    batch_size: int,
+    input_len: int,
+    output_len: int,
+    run_name: str,
+    result_filename: str,
+):
+    input_ids = [
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
+        for _ in range(batch_size)
+    ]
+
+    tic = time.time()
+    response = requests.post(
+        url + "/generate",
+        json={
+            "input_ids": input_ids,
+            "sampling_params": {
+                "temperature": 0,
+                "max_new_tokens": output_len,
+                "ignore_eos": True,
+            },
+        },
+    )
+    latency = time.time() - tic
+
+    _ = response.json()
+    output_throughput = batch_size * output_len / latency
+    overall_throughput = batch_size * (input_len + output_len) / latency
+
+    print(f"batch size: {batch_size}")
+    print(f"latency: {latency:.2f} s")
+    print(f"output throughput: {output_throughput:.2f} token/s")
+    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+
+    if result_filename:
+        with open(result_filename, "a") as fout:
+            res = {
+                "run_name": run_name,
+                "batch_size": batch_size,
+                "input_len": input_len,
+                "output_len": output_len,
+                "latency": round(latency, 4),
+                "output_throughput": round(output_throughput, 2),
+                "overall_throughput": round(overall_throughput, 2),
+            }
+            fout.write(json.dumps(res) + "\n")
+
+
+def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
+    proc, base_url = launch_server_process(server_args)
+
+    # warmup
+    run_one_case(
+        base_url,
+        batch_size=16,
+        input_len=1024,
+        output_len=16,
+        run_name="",
+        result_filename="",
+    )
+
+    # benchmark
+    try:
+        for bs, il, ol in itertools.product(
+            bench_args.batch_size, bench_args.input_len, bench_args.output_len
+        ):
+            run_one_case(
+                base_url,
+                bs,
+                il,
+                ol,
+                bench_args.run_name,
+                bench_args.result_filename,
+            )
+    finally:
+        kill_child_process(proc.pid)
+
+    print(f"\nResults are saved to {bench_args.result_filename}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code need to be updated"
+    parser._actions[1].required = False
+    args = parser.parse_args()
+
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    run_benchmark(server_args, bench_args)
sglang/bench_serving.py
CHANGED
@@ -2,7 +2,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
 
 """
-Benchmark online serving.
+Benchmark online serving with dynamic requests.
 
 Usage:
 python3 -m sglang.bench_serving --backend sglang --num-prompt 10
sglang/srt/layers/activation.py
CHANGED
@@ -19,17 +19,21 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
+
+from sglang.srt.utils import is_hip
+
+if not is_hip():
+    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.utils import set_weight_attrs
 
-from sglang.srt.
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
sglang/srt/layers/attention_backend.py
CHANGED
@@ -346,7 +346,9 @@ class TritonAttnBackend(AttentionBackend):
 
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
-        self.num_head =
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
 
         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
             self.reduce_dtype = torch.float32
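The re-wrapped assignment computes how many attention heads each tensor-parallel rank owns: the model's total head count divided by the tensor-parallel world size. A worked example with illustrative values (not taken from this diff):

num_attention_heads = 32  # total heads in the model config
tp_size = 4               # tensor-parallel world size
num_head = num_attention_heads // tp_size  # 8 heads per rank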
sglang/srt/layers/layernorm.py
CHANGED
@@ -20,16 +20,19 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from flashinfer.norm import (
-    fused_add_rmsnorm,
-    gemma_fused_add_rmsnorm,
-    gemma_rmsnorm,
-    rmsnorm,
-)
-from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.utils import is_hip
 
+if not is_hip():
+    from flashinfer.norm import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
+
+from vllm.model_executor.custom_op import CustomOp
+
 logger = logging.getLogger(__name__)
 
 
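Like activation.py above, this file now guards the flashinfer imports behind is_hip(), so ROCm builds never import the fused kernels and the layer has to take a non-flashinfer path. A minimal sketch of such a fallback, assuming a plain-PyTorch RMSNorm (the function name and eps default are illustrative, not part of this diff):

import torch

def rmsnorm_native(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain-PyTorch RMSNorm, usable when the fused flashinfer kernel is unavailable (e.g. on HIP).
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return (x_normed * weight.float()).to(x.dtype)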