sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py
CHANGED
@@ -17,11 +17,12 @@ import logging
 import os
 import random
 import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

 import numpy as np

 from sglang.bench_serving import (
+    DatasetRow,
     get_dataset,
     get_tokenizer,
     sample_random_requests,
@@ -194,7 +195,7 @@ class BenchArgs:
 def throughput_test_once(
     backend_name: str,
     backend,
-    reqs: List[Tuple[str, int, int]],
+    reqs: List[DatasetRow],
     ignore_eos: bool,
     extra_request_body: Dict,
     profile: bool,
@@ -203,7 +204,7 @@
         "backend": backend_name,
         "successful_requests": len(reqs),
         "total_latency": -1,
-        "total_input_tokens": sum(r[1] for r in reqs),
+        "total_input_tokens": sum(r.prompt_len for r in reqs),
         "total_output_tokens": -1,
         "request_throughput": -1,
         "input_throughput": -1,
@@ -211,11 +212,11 @@
         "total_throughput": -1,
     }

-    prompt = [r[0] for r in reqs]
+    prompt = [r.prompt for r in reqs]
     sampling_params = [
         {
             "temperature": 0,
-            "max_new_tokens": r[2],
+            "max_new_tokens": r.output_len,
             "ignore_eos": ignore_eos,
             **extra_request_body,
         }
@@ -259,13 +260,14 @@
         measurement_results["total_input_tokens"]
         + measurement_results["total_output_tokens"]
     ) / latency
-    measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
+    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+        "last_gen_throughput"
+    ]

     return measurement_results


 def monitor_trace_file(directory, interval=1):
-
     print(f"Monitoring {directory} for new trace files...")

     known_files = set(os.listdir(directory))
@@ -315,7 +317,7 @@ throughput_test(
     tokenizer_id = server_args.tokenizer_path or server_args.model_path
     tokenizer = get_tokenizer(tokenizer_id)

-    # Set global
+    # Set global environments
     set_ulimit()
     random.seed(bench_args.seed)
     np.random.seed(bench_args.seed)
sglang/bench_one_batch.py
CHANGED
@@ -246,7 +246,7 @@ def extend(reqs, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch

@@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits

@@ -269,6 +269,7 @@ def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
         batch,
         dp_size=model_runner.server_args.dp_size,
         attn_tp_size=1,
+        moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
         tp_cpu_group=model_runner.tp_group.cpu_group,
         get_idle_batch=None,
         disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
@@ -372,10 +373,10 @@

     # Prefill
     synchronize(device)
-    tic = time.time()
+    tic = time.perf_counter()
     next_token_ids, _, batch = extend(reqs, model_runner)
     synchronize(device)
-    prefill_latency = time.time() - tic
+    prefill_latency = time.perf_counter() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
     rank_print(
@@ -388,10 +389,10 @@
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
-        tic = time.time()
+        tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
-        latency = time.time() - tic
+        latency = time.perf_counter() - tic
         tot_latency += latency
         throughput = batch_size / latency
         decode_latencies.append(latency)
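Note: both extend() and decode() above now unpack a pair from ModelRunner.forward(), so in this release the forward pass returns a second value alongside the logits, which the benchmark discards. A minimal compatibility sketch, assuming only what the diff shows (the helper name and the tuple check are ours, not sglang API):

    def forward_logits(model_runner, forward_batch):
        # Return just the logits output, tolerating both return shapes:
        # 0.4.6.post5 callers unpack (logits_output, _); older releases
        # returned the logits output alone.
        ret = model_runner.forward(forward_batch)
        return ret[0] if isinstance(ret, tuple) else ret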
sglang/bench_one_batch_server.py
CHANGED
@@ -22,9 +22,11 @@ from typing import Tuple
 import numpy as np
 import requests

+from sglang.bench_serving import get_tokenizer, sample_random_requests
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import is_in_ci, write_github_step_summary


 @dataclasses.dataclass
@@ -33,9 +35,13 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    temperature: float = 0.0
+    return_logprob: bool = False
+    input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
     skip_warmup: bool = False
+    show_report: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -49,11 +55,19 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--input-len-step-percentage",
+            type=float,
+            default=BenchArgs.input_len_step_percentage,
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
         parser.add_argument("--skip-warmup", action="store_true")
+        parser.add_argument("--show-report", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -79,8 +93,8 @@ def launch_server_process(server_args: ServerArgs):
     base_url = f"http://{server_args.host}:{server_args.port}"
     timeout = 600

-    start_time = time.time()
-    while time.time() - start_time < timeout:
+    start_time = time.perf_counter()
+    while time.perf_counter() - start_time < timeout:
         try:
             headers = {
                 "Content-Type": "application/json; charset=utf-8",
@@ -99,36 +113,91 @@ def run_one_case(
     batch_size: int,
     input_len: int,
     output_len: int,
+    temperature: float,
+    return_logprob: bool,
+    input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
+    tokenizer,
 ):
-    input_ids = [
-        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
-        for _ in range(batch_size)
-    ]
+    requests.post(url + "/flush_cache")
+    input_requests = sample_random_requests(
+        input_len=input_len,
+        output_len=output_len,
+        num_prompts=batch_size,
+        range_ratio=1.0,
+        tokenizer=tokenizer,
+        dataset_path="",
+        random_sample=True,
+        return_text=False,
+    )
+
+    use_structured_outputs = False
+    if use_structured_outputs:
+        texts = []
+        for _ in range(batch_size):
+            texts.append(
+                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
+                * 50
+                + "Assistant:"
+            )
+        json_schema = "$$ANY$$"
+    else:
+        json_schema = None

-    tic = time.time()
+    tic = time.perf_counter()
     response = requests.post(
         url + "/generate",
         json={
-            "input_ids": input_ids,
+            "input_ids": [req.prompt for req in input_requests],
             "sampling_params": {
-                "temperature": 0,
+                "temperature": temperature,
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
+                "json_schema": json_schema,
             },
+            "return_logprob": return_logprob,
+            "stream": True,
         },
+        stream=True,
     )
-    latency = time.time() - tic

-    ret = response.json()
-    output_throughput = batch_size * output_len / latency
+    # The TTFT of the last request in the batch
+    ttft = 0.0
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
+            if "error" in data:
+                raise RuntimeError(f"Request has failed. {data}.")
+
+            assert (
+                data["meta_info"]["finish_reason"] is None
+                or data["meta_info"]["finish_reason"]["type"] == "length"
+            )
+            if data["meta_info"]["completion_tokens"] == 1:
+                ttft = time.perf_counter() - tic
+
+    latency = time.perf_counter() - tic
+    input_throughput = batch_size * input_len / ttft
+    output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency

+    server_info = requests.get(url + "/get_server_info").json()
+    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
+    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
+
     print(f"batch size: {batch_size}")
+    print(f"input_len: {input_len}")
+    print(f"output_len: {output_len}")
     print(f"latency: {latency:.2f} s")
-    print(f"output throughput: {output_throughput:.2f} tok/s")
-    print(f"overall throughput: {overall_throughput:.2f} tok/s")
+    print(f"ttft: {ttft:.2f} s")
+    print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
+    print(f"Input throughput: {input_throughput:.2f} tok/s")
+    if output_len != 1:
+        print(f"output throughput: {output_throughput:.2f} tok/s")

     if result_filename:
         with open(result_filename, "a") as fout:
@@ -140,9 +209,21 @@ def run_one_case(
             "latency": round(latency, 4),
             "output_throughput": round(output_throughput, 2),
             "overall_throughput": round(overall_throughput, 2),
+            "last_gen_throughput": round(last_gen_throughput, 2),
         }
         fout.write(json.dumps(res) + "\n")

+    return (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    )
+

 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     if bench_args.base_url:
@@ -150,29 +231,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)

+    tokenizer_id = server_args.tokenizer_path or server_args.model_path
+    tokenizer = get_tokenizer(tokenizer_id)
+
     # warmup
     if not bench_args.skip_warmup:
+        print("=" * 8 + " Warmup Begin " + "=" * 8)
         run_one_case(
             base_url,
             batch_size=16,
             input_len=1024,
             output_len=16,
+            temperature=bench_args.temperature,
+            return_logprob=bench_args.return_logprob,
+            input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
+            tokenizer=tokenizer,
         )
+        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

     # benchmark
+    result = []
     try:
         for bs, il, ol in itertools.product(
             bench_args.batch_size, bench_args.input_len, bench_args.output_len
         ):
-            run_one_case(
-                base_url,
-                bs,
-                il,
-                ol,
-                run_name=bench_args.run_name,
-                result_filename=bench_args.result_filename,
+            result.append(
+                run_one_case(
+                    base_url,
+                    bs,
+                    il,
+                    ol,
+                    temperature=bench_args.temperature,
+                    return_logprob=bench_args.return_logprob,
+                    input_len_step_percentage=bench_args.input_len_step_percentage,
+                    run_name=bench_args.run_name,
+                    result_filename=bench_args.result_filename,
+                    tokenizer=tokenizer,
+                )
             )
     finally:
         if proc:
@@ -180,6 +277,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

     print(f"\nResults are saved to {bench_args.result_filename}")

+    if not bench_args.show_report:
+        return
+
+    summary = " | batch size | latency (s) | input throughput (tok/s)  | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
+    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+
+    for (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    ) in result:
+        hourly_cost = 2 * server_args.tp_size  # $2/hour for one H100
+        input_util = 0.7
+        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
+        line = (
+            f"| {batch_size} | "
+            f"{latency:.2f} | "
+            f"{input_throughput:.2f} | "
+            f"{output_throughput:.2f} | "
+            f"{accept_length} | "
+            f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
+            f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
+            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
+        )
+        summary += line
+
+    # print metrics table
+    print(summary)
+
+    if is_in_ci():
+        write_github_step_summary(
+            f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
+        )
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()