sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py
CHANGED
@@ -259,7 +259,9 @@ def throughput_test_once(
             measurement_results["total_input_tokens"]
             + measurement_results["total_output_tokens"]
         ) / latency
-        measurement_results["last_gen_throughput"] = server_info["
+        measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+            "last_gen_throughput"
+        ]

     return measurement_results

@@ -315,7 +317,7 @@ def throughput_test(
     tokenizer_id = server_args.tokenizer_path or server_args.model_path
     tokenizer = get_tokenizer(tokenizer_id)

-    # Set global
+    # Set global environments
     set_ulimit()
     random.seed(bench_args.seed)
     np.random.seed(bench_args.seed)
sglang/bench_one_batch.py
CHANGED
@@ -137,17 +137,7 @@ def load_model(server_args, port_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

-    model_config = ModelConfig(
-        server_args.model_path,
-        trust_remote_code=server_args.trust_remote_code,
-        revision=server_args.revision,
-        context_length=server_args.context_length,
-        model_override_args=server_args.json_model_override_args,
-        is_embedding=server_args.is_embedding,
-        enable_multimodal=server_args.enable_multimodal,
-        dtype=server_args.dtype,
-        quantization=server_args.quantization,
-    )
+    model_config = ModelConfig.from_server_args(server_args)
     model_runner = ModelRunner(
         model_config=model_config,
         mem_fraction_static=server_args.mem_fraction_static,

@@ -256,7 +246,7 @@ def extend(reqs, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch

@@ -268,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
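Both edits in this file track API changes elsewhere in the release: `ModelConfig` gains a `from_server_args` constructor, and `ModelRunner.forward` now returns a tuple whose first element is the logits output. A minimal sketch of the new calling convention, assuming `server_args`, `model_runner`, and `forward_batch` are already set up as in `bench_one_batch.py`:

```python
from sglang.srt.configs.model_config import ModelConfig

# Build the model config straight from the parsed server arguments instead of
# forwarding each field by hand (new in this release).
model_config = ModelConfig.from_server_args(server_args)

# forward() now returns a tuple; the benchmark only needs the logits output,
# so the second element is ignored here.
logits_output, _ = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch)
```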
sglang/bench_one_batch_server.py
CHANGED
@@ -25,6 +25,7 @@ import requests
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import is_in_ci, write_github_step_summary


 @dataclasses.dataclass

@@ -33,9 +34,13 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    temperature: float = 0.0
+    return_logprob: bool = False
+    input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
     skip_warmup: bool = False
+    show_report: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):

@@ -49,11 +54,19 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--input-len-step-percentage",
+            type=float,
+            default=BenchArgs.input_len_step_percentage,
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
         parser.add_argument("--skip-warmup", action="store_true")
+        parser.add_argument("--show-report", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):

@@ -99,36 +112,89 @@ def run_one_case(
     batch_size: int,
     input_len: int,
     output_len: int,
+    temperature: float,
+    return_logprob: bool,
+    input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
 ):
+    requests.post(url + "/flush_cache")
+    input_lens = [
+        int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
+        for i in range(batch_size)
+    ]
     input_ids = [
-        [int(x) for x in np.random.randint(0, high=16384, size=(
-        for
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
+        for i in range(batch_size)
     ]

+    use_structured_outputs = False
+    if use_structured_outputs:
+        texts = []
+        for _ in range(batch_size):
+            texts.append(
+                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
+                * 50
+                + "Assistant:"
+            )
+        json_schema = "$$ANY$$"
+    else:
+        json_schema = None
+
     tic = time.time()
     response = requests.post(
         url + "/generate",
         json={
+            # "text": texts,
             "input_ids": input_ids,
             "sampling_params": {
-                "temperature":
+                "temperature": temperature,
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
+                "json_schema": json_schema,
             },
+            "return_logprob": return_logprob,
+            "stream": True,
         },
+        stream=True,
     )
-    latency = time.time() - tic

-
-
+    # The TTFT of the last request in the batch
+    ttft = 0.0
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
+            if "error" in data:
+                raise RuntimeError(f"Request has failed. {data}.")
+
+            assert (
+                data["meta_info"]["finish_reason"] is None
+                or data["meta_info"]["finish_reason"]["type"] == "length"
+            )
+            if data["meta_info"]["completion_tokens"] == 1:
+                ttft = time.time() - tic
+
+    latency = time.time() - tic
+    input_throughput = batch_size * input_len / ttft
+    output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency

+    server_info = requests.get(url + "/get_server_info").json()
+    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
+    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
+
     print(f"batch size: {batch_size}")
+    print(f"input_len: {input_len}")
+    print(f"output_len: {output_len}")
     print(f"latency: {latency:.2f} s")
-    print(f"
-    print(f"
+    print(f"ttft: {ttft:.2f} s")
+    print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
+    print(f"Input throughput: {input_throughput:.2f} tok/s")
+    if output_len != 1:
+        print(f"output throughput: {output_throughput:.2f} tok/s")

     if result_filename:
         with open(result_filename, "a") as fout:

@@ -140,9 +206,21 @@ def run_one_case(
                 "latency": round(latency, 4),
                 "output_throughput": round(output_throughput, 2),
                 "overall_throughput": round(overall_throughput, 2),
+                "last_gen_throughput": round(last_gen_throughput, 2),
             }
             fout.write(json.dumps(res) + "\n")

+    return (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    )
+

 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     if bench_args.base_url:

@@ -152,27 +230,38 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

     # warmup
     if not bench_args.skip_warmup:
+        print("=" * 8 + " Warmup Begin " + "=" * 8)
         run_one_case(
             base_url,
             batch_size=16,
             input_len=1024,
             output_len=16,
+            temperature=bench_args.temperature,
+            return_logprob=bench_args.return_logprob,
+            input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
         )
+        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

     # benchmark
+    result = []
     try:
         for bs, il, ol in itertools.product(
             bench_args.batch_size, bench_args.input_len, bench_args.output_len
         ):
-
-
-
-
-
-
-
+            result.append(
+                run_one_case(
+                    base_url,
+                    bs,
+                    il,
+                    ol,
+                    temperature=bench_args.temperature,
+                    return_logprob=bench_args.return_logprob,
+                    input_len_step_percentage=bench_args.input_len_step_percentage,
+                    run_name=bench_args.run_name,
+                    result_filename=bench_args.result_filename,
+                )
             )
     finally:
         if proc:

@@ -180,6 +269,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

     print(f"\nResults are saved to {bench_args.result_filename}")

+    if not bench_args.show_report:
+        return
+
+    summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
+    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+
+    for (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    ) in result:
+        hourly_cost = 2 * server_args.tp_size  # $2/hour for one H100
+        input_util = 0.7
+        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
+        line = (
+            f"| {batch_size} | "
+            f"{latency:.2f} | "
+            f"{input_throughput:.2f} | "
+            f"{output_throughput:.2f} | "
+            f"{accept_length} | "
+            f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
+            f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
+            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
+        )
+        summary += line
+
+    # print metrics table
+    print(summary)
+
+    if is_in_ci():
+        write_github_step_summary(
+            f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
+        )
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
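The new `--show-report` path derives two columns that are not measured directly: per-request inter-token latency (ITL) and a rough dollar cost per million tokens, using the hard-coded assumptions from the diff ($2/hour per H100 times `tp_size`, and 70% effective input utilization). A small sketch of that arithmetic with made-up throughput numbers:

```python
# Illustrative inputs only; in the benchmark these come from run_one_case().
batch_size = 8
input_throughput = 20_000.0   # tok/s across the batch (hypothetical)
output_throughput = 1_600.0   # tok/s across the batch (hypothetical)
tp_size = 1

hourly_cost = 2 * tp_size  # $2/hour for one H100, as hard-coded in the diff
input_util = 0.7           # assumed effective utilization of input throughput

# Inter-token latency per request, in milliseconds.
itl_ms = 1 / (output_throughput / batch_size) * 1000

# Dollar cost to process one million input / output tokens.
input_price = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost
output_price = 1e6 / output_throughput / 3600 * hourly_cost

print(f"ITL: {itl_ms:.2f} ms, input: ${input_price:.2f}/1M, output: ${output_price:.2f}/1M")
```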
sglang/bench_serving.py
CHANGED
@@ -58,6 +58,7 @@ class RequestFuncInput:
     output_len: int
     model: str
     lora_name: str
+    image_data: str
     extra_request_body: Dict[str, Any]


@@ -347,6 +348,11 @@ async def async_request_sglang_generate(
             "logprob_start_len": -1,
             **request_func_input.extra_request_body,
         }
+
+        # Add image data if available
+        if request_func_input.image_data:
+            payload["image_data"] = request_func_input.image_data
+
         headers = get_auth_headers()

         output = RequestFuncOutput()

@@ -510,6 +516,13 @@ def get_dataset(args, tokenizer):
             tokenizer=tokenizer,
             args=args,
         )
+    elif args.dataset_name == "mmmu":
+        input_requests = sample_mmmu_requests(
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            fixed_output_len=args.random_output_len,
+            random_sample=True,
+        )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
     return input_requests

@@ -597,6 +610,121 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
     return filename


+def sample_mmmu_requests(
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int] = None,
+    random_sample: bool = True,
+) -> List[Tuple[str, int, int]]:
+    """
+    Sample requests from the MMMU dataset using HuggingFace datasets.
+
+    Args:
+        num_requests: Number of requests to sample.
+        tokenizer: Tokenizer to use for token counting.
+        fixed_output_len: If provided, use this fixed output length for all requests.
+        random_sample: Whether to randomly sample or take the first N.
+
+    Returns:
+        List of tuples (prompt, prompt_token_len, output_token_len).
+    """
+    try:
+        import base64
+        import io
+
+        from datasets import load_dataset
+    except ImportError:
+        raise ImportError("Please install datasets: pip install datasets")
+
+    print("Loading MMMU dataset from HuggingFace...")
+
+    try:
+        print("Attempting to load MMMU Math dataset...")
+        mmmu_dataset = load_dataset("MMMU/MMMU", "Math", split="test")
+        print(
+            f"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples"
+        )
+    except Exception as e:
+        print(f"Failed to load MMMU Math dataset: {e}")
+        raise ValueError(f"Failed to load MMMU dataset: {e}")
+
+    # Sample from the dataset
+    if len(mmmu_dataset) > num_requests:
+        if random_sample:
+            # Random sample
+            indices = random.sample(range(len(mmmu_dataset)), num_requests)
+            sample_dataset = mmmu_dataset.select(indices)
+        else:
+            # Take first N
+            sample_dataset = mmmu_dataset.select(
+                range(min(num_requests, len(mmmu_dataset)))
+            )
+    else:
+        print(f"Dataset has less than {num_requests} examples, using all examples")
+        sample_dataset = mmmu_dataset
+
+    print(f"Selected {len(sample_dataset)} examples for benchmarking")
+
+    # Create prompts
+    filtered_dataset = []
+
+    for i, example in enumerate(sample_dataset):
+        try:
+            # Extract image_1
+            image = example.get("image_1")
+
+            if image is not None:
+                if hasattr(image, "save"):
+                    # Convert RGBA images to RGB before encoding
+                    if image.mode == "RGBA":
+                        image = image.convert("RGB")
+
+                    # Encode image to base64
+                    buffered = io.BytesIO()
+                    image.save(buffered, format="JPEG")
+                    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                    image_path = f"data:image/jpeg;base64,{img_str}"
+                else:
+                    continue
+
+                # Extract the question
+                question = example.get("question")
+
+                # Create the prompt with image, question
+                prompt = f"Question: {question}\n\nAnswer: "
+                prompt = tokenizer.apply_chat_template(
+                    [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "image_url", "image_url": {"url": image_path}},
+                                {"type": "text", "text": prompt},
+                            ],
+                        }
+                    ],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+                prompt = f"<image>{image_path}</image>{prompt}"
+
+                # Calculate token lengths
+                # Note: This is approximate since we're not rendering the actual image tokens
+                prompt_token_ids = tokenizer.encode(prompt)
+                prompt_len = (
+                    len(prompt_token_ids) + 512
+                )  # Add estimate for image tokens
+
+                output_len = fixed_output_len if fixed_output_len is not None else 256
+
+                filtered_dataset.append((prompt, prompt_len, output_len))
+
+        except Exception as e:
+            print(f"Error processing example {i}: {e}")
+
+    print(f"\nCreated {len(filtered_dataset)} MMMU prompts")
+    return filtered_dataset
+
+
 def sample_sharegpt_requests(
     dataset_path: str,
     num_requests: int,

@@ -975,7 +1103,7 @@ async def benchmark(
     lora_names: List[str],
     extra_request_body: Dict[str, Any],
     profile: bool,
-
+    pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
 ):

@@ -1004,6 +1132,15 @@
     else:
         lora_name = None

+    if "<image>" in test_prompt:
+        import re
+
+        image_match = re.search(r"<image>(.*?)</image>(.*)", test_prompt)
+        image_data = image_match.group(1) if image_match else None
+        test_prompt = image_match.group(2) if image_match else test_prompt
+    else:
+        image_data = None
+
     # Create the test input once
     test_input = RequestFuncInput(
         model=model_id,

@@ -1012,6 +1149,7 @@
         prompt_len=test_prompt_len,
         output_len=min(test_output_len, 32),
         lora_name=lora_name,
+        image_data=image_data,
         extra_request_body=extra_request_body,
     )


@@ -1063,6 +1201,15 @@
         else:
             lora_name = None

+        if "<image>" in prompt:
+            import re
+
+            image_match = re.search(r"<image>(.*?)</image>(.*)", prompt)
+            image_data = image_match.group(1) if image_match else None
+            prompt = image_match.group(2) if image_match else prompt
+        else:
+            image_data = None
+
         request_func_input = RequestFuncInput(
             model=model_id,
             prompt=prompt,

@@ -1070,6 +1217,7 @@
             prompt_len=prompt_len,
             output_len=output_len,
             lora_name=lora_name,
+            image_data=image_data,
             extra_request_body=extra_request_body,
         )
         tasks.append(

@@ -1091,12 +1239,14 @@

     if "sglang" in backend:
         server_info = requests.get(base_url + "/get_server_info")
-        if
-            accept_length = server_info.json()["decode"][0].get(
+        if pd_separated:
+            accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
                 "avg_spec_accept_length", None
             )
         else:
-            accept_length = server_info.json().get(
+            accept_length = server_info.json()["internal_states"][0].get(
+                "avg_spec_accept_length", None
+            )
     else:
         accept_length = None

@@ -1115,7 +1265,7 @@
     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
     print(
         "{:<40} {:<10}".format(
-            "Max
+            "Max request concurrency:",
             max_concurrency if max_concurrency else "not set",
         )
     )

@@ -1393,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace):
             lora_names=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
-
+            pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
         )
     )

@@ -1444,7 +1594,7 @@ if __name__ == "__main__":
         "--dataset-name",
         type=str,
         default="sharegpt",
-        choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
+        choices=["sharegpt", "random", "random-ids", "generated-shared-prefix", "mmmu"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument(

@@ -1572,7 +1722,7 @@
         help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
     )
     parser.add_argument(
-        "--pd-
+        "--pd-separated",
         action="store_true",
         help="Benchmark PD disaggregation server",
     )
sglang/compile_deep_gemm.py
CHANGED
@@ -129,7 +129,7 @@ def launch_server_process_and_send_one_request(


 def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
-    #
+    # Disable cuda graph and torch compile to save time
     server_args.disable_cuda_graph = True
     server_args.enable_torch_compile = False
     print(f"Disable CUDA Graph and Torch Compile to save time...")
|