sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py
CHANGED
@@ -17,11 +17,12 @@ import logging
 import os
 import random
 import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

 import numpy as np

 from sglang.bench_serving import (
+    DatasetRow,
     get_dataset,
     get_tokenizer,
     sample_random_requests,
@@ -194,7 +195,7 @@ class BenchArgs:
 def throughput_test_once(
     backend_name: str,
     backend,
-    reqs: List[Tuple[str, int, int]],
+    reqs: List[DatasetRow],
     ignore_eos: bool,
     extra_request_body: Dict,
     profile: bool,
@@ -203,7 +204,7 @@ def throughput_test_once(
         "backend": backend_name,
         "successful_requests": len(reqs),
         "total_latency": -1,
-        "total_input_tokens": sum(r[1] for r in reqs),
+        "total_input_tokens": sum(r.prompt_len for r in reqs),
         "total_output_tokens": -1,
         "request_throughput": -1,
         "input_throughput": -1,
@@ -211,11 +212,11 @@ def throughput_test_once(
         "total_throughput": -1,
     }

-    prompt = [r[0] for r in reqs]
+    prompt = [r.prompt for r in reqs]
     sampling_params = [
         {
             "temperature": 0,
-            "max_new_tokens": r[2],
+            "max_new_tokens": r.output_len,
             "ignore_eos": ignore_eos,
             **extra_request_body,
         }
@@ -267,7 +268,6 @@ def throughput_test_once(


 def monitor_trace_file(directory, interval=1):
-
     print(f"Monitoring {directory} for new trace files...")

     known_files = set(os.listdir(directory))
sglang/bench_one_batch.py
CHANGED
@@ -269,6 +269,7 @@ def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
         batch,
         dp_size=model_runner.server_args.dp_size,
         attn_tp_size=1,
+        moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
         tp_cpu_group=model_runner.tp_group.cpu_group,
         get_idle_batch=None,
         disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
@@ -372,10 +373,10 @@ def latency_test_run_once(

     # Prefill
     synchronize(device)
-    tic = time.time()
+    tic = time.perf_counter()
     next_token_ids, _, batch = extend(reqs, model_runner)
     synchronize(device)
-    prefill_latency = time.time() - tic
+    prefill_latency = time.perf_counter() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
     rank_print(
@@ -388,10 +389,10 @@ def latency_test_run_once(
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
-        tic = time.time()
+        tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
-        latency = time.time() - tic
+        latency = time.perf_counter() - tic
         tot_latency += latency
         throughput = batch_size / latency
         decode_latencies.append(latency)
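Every timing hunk in this release, here and in sglang/bench_one_batch_server.py and sglang/compile_deep_gemm.py below, replaces a wall-clock time.time() call (shown truncated as `time.` in the source view) with time.perf_counter(). A minimal sketch of the pattern, with do_work() as a hypothetical placeholder for the measured region; perf_counter() is a monotonic, high-resolution clock, so measured intervals are not skewed by system clock adjustments:

```python
# Sketch only: do_work() is a hypothetical placeholder for the code being timed.
import time


def do_work():
    sum(range(1_000_000))


tic = time.perf_counter()
do_work()
latency = time.perf_counter() - tic  # seconds, from a monotonic clock
print(f"latency: {latency:.6f} s")
```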
sglang/bench_one_batch_server.py
CHANGED
@@ -22,6 +22,7 @@ from typing import Tuple
 import numpy as np
 import requests

+from sglang.bench_serving import get_tokenizer, sample_random_requests
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
@@ -92,8 +93,8 @@ def launch_server_process(server_args: ServerArgs):
     base_url = f"http://{server_args.host}:{server_args.port}"
     timeout = 600

-    start_time = time.time()
-    while time.time() - start_time < timeout:
+    start_time = time.perf_counter()
+    while time.perf_counter() - start_time < timeout:
         try:
             headers = {
                 "Content-Type": "application/json; charset=utf-8",
@@ -117,16 +118,19 @@ def run_one_case(
     input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
+    tokenizer,
 ):
     requests.post(url + "/flush_cache")
-
-
-
-
-
-
-
-
+    input_requests = sample_random_requests(
+        input_len=input_len,
+        output_len=output_len,
+        num_prompts=batch_size,
+        range_ratio=1.0,
+        tokenizer=tokenizer,
+        dataset_path="",
+        random_sample=True,
+        return_text=False,
+    )

     use_structured_outputs = False
     if use_structured_outputs:
@@ -141,12 +145,11 @@ def run_one_case(
     else:
         json_schema = None

-    tic = time.time()
+    tic = time.perf_counter()
     response = requests.post(
         url + "/generate",
         json={
-
-            "input_ids": input_ids,
+            "input_ids": [req.prompt for req in input_requests],
             "sampling_params": {
                 "temperature": temperature,
                 "max_new_tokens": output_len,
@@ -175,9 +178,9 @@
                 or data["meta_info"]["finish_reason"]["type"] == "length"
             )
             if data["meta_info"]["completion_tokens"] == 1:
-                ttft = time.time() - tic
+                ttft = time.perf_counter() - tic

-    latency = time.time() - tic
+    latency = time.perf_counter() - tic
     input_throughput = batch_size * input_len / ttft
     output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency
@@ -228,6 +231,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)

+    tokenizer_id = server_args.tokenizer_path or server_args.model_path
+    tokenizer = get_tokenizer(tokenizer_id)
+
     # warmup
     if not bench_args.skip_warmup:
         print("=" * 8 + " Warmup Begin " + "=" * 8)
@@ -241,6 +247,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
            input_len_step_percentage=bench_args.input_len_step_percentage,
            run_name="",
            result_filename="",
+           tokenizer=tokenizer,
         )
         print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

@@ -261,6 +268,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                    input_len_step_percentage=bench_args.input_len_step_percentage,
                    run_name=bench_args.run_name,
                    result_filename=bench_args.result_filename,
+                   tokenizer=tokenizer,
                 )
             )
     finally:
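With these changes run_one_case no longer builds its own token IDs: run_benchmark loads the tokenizer once, and each case draws its batch from sample_random_requests with return_text=False, so every returned row's prompt field holds token IDs that are sent directly as "input_ids" to /generate. A rough sketch of that flow; the model name, lengths, and batch size below are illustrative placeholders, not values from the diff:

```python
# Hedged sketch of the new request-construction path in bench_one_batch_server.
from sglang.bench_serving import get_tokenizer, sample_random_requests

tokenizer = get_tokenizer("meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
input_requests = sample_random_requests(
    input_len=128,
    output_len=32,
    num_prompts=4,
    range_ratio=1.0,
    tokenizer=tokenizer,
    dataset_path="",
    random_sample=True,
    return_text=False,  # keep token IDs instead of decoded text
)
# Each row's prompt is a list of token IDs, suitable for the "input_ids" field.
payload = {"input_ids": [req.prompt for req in input_requests]}
```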
sglang/bench_serving.py
CHANGED
@@ -24,6 +24,7 @@ import warnings
 from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
+from json import JSONDecodeError
 from pathlib import Path
 from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

@@ -73,6 +74,12 @@ class RequestFuncOutput:
     error: str = ""
     output_len: int = 0

+    @staticmethod
+    def init_new(request_func_input: RequestFuncInput):
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+        return output
+

 def remove_prefix(text: str, prefix: str) -> str:
     return text[len(prefix) :] if text.startswith(prefix) else text
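The init_new factory added above centralizes the two-line construction (create a RequestFuncOutput, copy prompt_len from the input) that every backend-specific request function previously repeated; the next four hunks switch those call sites over. A minimal stand-alone sketch of the pattern, using simplified stand-ins for the real dataclasses:

```python
# Stand-ins for illustration only; the real classes in sglang.bench_serving
# carry many more fields (latency, ttft, itl, generated_text, ...).
from dataclasses import dataclass


@dataclass
class RequestFuncInput:
    prompt_len: int = 0


@dataclass
class RequestFuncOutput:
    prompt_len: int = 0

    @staticmethod
    def init_new(request_func_input: "RequestFuncInput") -> "RequestFuncOutput":
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        return output


# Before: output = RequestFuncOutput(); output.prompt_len = request_func_input.prompt_len
output = RequestFuncOutput.init_new(RequestFuncInput(prompt_len=42))
assert output.prompt_len == 42
```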
@@ -114,8 +121,7 @@ async def async_request_trt_llm(
     if args.disable_ignore_eos:
         del payload["min_length"]
         del payload["end_id"]
-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     ttft = 0.0
     st = time.perf_counter()
@@ -186,8 +192,7 @@ async def async_request_openai_completions(
     }
     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     output_len = request_func_input.output_len
@@ -269,8 +274,7 @@ async def async_request_truss(
     }
     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     ttft = 0.0
@@ -355,8 +359,7 @@ async def async_request_sglang_generate(

     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     output_len = request_func_input.output_len
@@ -469,6 +472,10 @@ def get_model(pretrained_model_name_or_path: str) -> str:
 def get_tokenizer(
     pretrained_model_name_or_path: str,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    assert (
+        pretrained_model_name_or_path is not None
+        and pretrained_model_name_or_path != ""
+    )
     if pretrained_model_name_or_path.endswith(
         ".json"
     ) or pretrained_model_name_or_path.endswith(".model"):
@@ -582,7 +589,7 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
         filename = os.path.join("/tmp", url.split("/")[-1])

     # Check if the cache file already exists
-    if os.path.exists(filename):
+    if is_file_valid_json(filename):
         return filename

     print(f"Downloading from {url} to {filename}")
@@ -610,12 +617,35 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
     return filename


+def is_file_valid_json(path):
+    if not os.path.isfile(path):
+        return False
+
+    # TODO can fuse into the real file open later
+    try:
+        with open(path) as f:
+            json.load(f)
+        return True
+    except JSONDecodeError as e:
+        print(
+            f"{path} exists but json loading fails ({e=}), thus treat as invalid file"
+        )
+        return False
+
+
+@dataclass
+class DatasetRow:
+    prompt: str
+    prompt_len: int
+    output_len: int
+
+
 def sample_mmmu_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
     random_sample: bool = True,
-) -> List[Tuple[str, int, int]]:
+) -> List[DatasetRow]:
     """
     Sample requests from the MMMU dataset using HuggingFace datasets.

@@ -716,7 +746,11 @@ def sample_mmmu_requests(

                 output_len = fixed_output_len if fixed_output_len is not None else 256

-                filtered_dataset.append((prompt, prompt_len, output_len))
+                filtered_dataset.append(
+                    DatasetRow(
+                        prompt=prompt, prompt_len=prompt_len, output_len=output_len
+                    )
+                )

             except Exception as e:
                 print(f"Error processing example {i}: {e}")
@@ -733,12 +767,12 @@ def sample_sharegpt_requests(
     context_len: Optional[int] = None,
     prompt_suffix: Optional[str] = "",
     apply_chat_template=False,
-) -> List[Tuple[str, int, int]]:
+) -> List[DatasetRow]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")

     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and dataset_path == "":
+    if not is_file_valid_json(dataset_path) and dataset_path == "":
         dataset_path = download_and_cache_file(SHAREGPT_URL)

     # Load the dataset.
@@ -764,7 +798,7 @@ def sample_sharegpt_requests(
     random.shuffle(dataset)

     # Filter out sequences that are too long or too short
-    filtered_dataset: List[Tuple[str, int, int]] = []
+    filtered_dataset: List[DatasetRow] = []
     for i in range(len(dataset)):
         if len(filtered_dataset) == num_requests:
             break
@@ -802,10 +836,12 @@ def sample_sharegpt_requests(
             # Prune too long sequences.
             continue

-        filtered_dataset.append((prompt, prompt_len, output_len))
+        filtered_dataset.append(
+            DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len)
+        )

-    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
-    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
+    print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
     return filtered_dataset

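The DatasetRow dataclass introduced above is the centerpiece of this release's benchmarking changes: the sampling helpers now return rows with named fields instead of (prompt, prompt_len, output_len) tuples, and consumers read row.prompt_len rather than row[1]. A small self-contained sketch of the new access pattern, with fabricated example rows and lengths:

```python
# Stand-in mirroring the dataclass added to sglang/bench_serving.py.
from dataclasses import dataclass


@dataclass
class DatasetRow:
    prompt: str
    prompt_len: int
    output_len: int


rows = [
    DatasetRow(prompt="Explain KV-cache reuse.", prompt_len=6, output_len=64),
    DatasetRow(prompt="Summarize the release notes.", prompt_len=5, output_len=32),
]

# Old tuple style was positional: sum(r[1] for r in reqs), r[2] for max_new_tokens.
total_input_tokens = sum(r.prompt_len for r in rows)
max_new_tokens = [r.output_len for r in rows]
print(total_input_tokens, max_new_tokens)
```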
@@ -817,7 +853,8 @@ def sample_random_requests(
     tokenizer: PreTrainedTokenizerBase,
     dataset_path: str,
     random_sample: bool = True,
-) -> List[Tuple[str, int, int]]:
+    return_text: bool = True,
+) -> List[DatasetRow]:
     input_lens = np.random.randint(
         max(int(input_len * range_ratio), 1),
         input_len + 1,
@@ -833,7 +870,7 @@ def sample_random_requests(
         # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

         # Download sharegpt if necessary
-        if not os.path.isfile(dataset_path):
+        if not is_file_valid_json(dataset_path):
             dataset_path = download_and_cache_file(SHAREGPT_URL)

         # Load the dataset.
@@ -857,7 +894,7 @@ def sample_random_requests(
         random.shuffle(dataset)

         # Filter out sequences that are too long or too short
-        input_requests: List[Tuple[str, int, int]] = []
+        input_requests: List[DatasetRow] = []
         for data in dataset:
             i = len(input_requests)
             if i == num_prompts:
@@ -877,20 +914,34 @@ def sample_random_requests(
             else:
                 ratio = (input_lens[i] + prompt_len - 1) // prompt_len
                 input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
-            prompt = tokenizer.decode(input_ids)
-            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+            input_content = input_ids
+            if return_text:
+                input_content = tokenizer.decode(input_content)
+            input_requests.append(
+                DatasetRow(
+                    prompt=input_content,
+                    prompt_len=int(input_lens[i]),
+                    output_len=int(output_lens[i]),
+                )
+            )
     else:
         # Sample token ids from random integers. This can cause some NaN issues.
         offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
         input_requests = []
         for i in range(num_prompts):
-            prompt = tokenizer.decode(
-                [
-                    (offsets[i] + i + j) % tokenizer.vocab_size
-                    for j in range(input_lens[i])
-                ]
+            input_content = [
+                (offsets[i] + i + j) % tokenizer.vocab_size
+                for j in range(input_lens[i])
+            ]
+            if return_text:
+                input_content = tokenizer.decode(input_content)
+            input_requests.append(
+                DatasetRow(
+                    prompt=input_content,
+                    prompt_len=int(input_lens[i]),
+                    output_len=int(output_lens[i]),
+                )
             )
-            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

     print(f"#Input tokens: {np.sum(input_lens)}")
     print(f"#Output tokens: {np.sum(output_lens)}")
@@ -925,7 +976,7 @@ def sample_generated_shared_prefix_requests(
     output_len: int,
     tokenizer: PreTrainedTokenizerBase,
     args: argparse.Namespace,
-) -> List[Tuple[str, int, int]]:
+) -> List[DatasetRow]:
     """Generate benchmark requests with shared system prompts using random tokens and caching."""
     cache_path = get_gen_prefix_cache_path(args, tokenizer)

@@ -963,7 +1014,11 @@ def sample_generated_shared_prefix_requests(
             full_prompt = f"{system_prompt}\n\n{question}"
             prompt_len = len(tokenizer.encode(full_prompt))

-            input_requests.append((full_prompt, prompt_len, output_len))
+            input_requests.append(
+                DatasetRow(
+                    prompt=full_prompt, prompt_len=prompt_len, output_len=output_len
+                )
+            )
             total_input_tokens += prompt_len
             total_output_tokens += output_len

@@ -994,9 +1049,9 @@ def sample_generated_shared_prefix_requests(


 async def get_request(
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[DatasetRow],
     request_rate: float,
-) -> AsyncGenerator[Tuple[str, int, int], None]:
+) -> AsyncGenerator[DatasetRow, None]:
     input_requests = iter(input_requests)
     for request in input_requests:
         yield request
@@ -1012,7 +1067,7 @@ async def get_request(


 def calculate_metrics(
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[DatasetRow],
     outputs: List[RequestFuncOutput],
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
@@ -1034,7 +1089,7 @@ def calculate_metrics(
                 tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
             )
             retokenized_output_lens.append(retokenized_output_len)
-            total_input += input_requests[i][1]
+            total_input += input_requests[i].prompt_len
             if output_len > 1:
                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
                 itls += outputs[i].itl
@@ -1096,7 +1151,7 @@ async def benchmark(
     base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[DatasetRow],
     request_rate: float,
     max_concurrency: Optional[int],
     disable_tqdm: bool,
@@ -1126,7 +1181,12 @@ async def benchmark(
     print(f"Starting warmup with {warmup_requests} sequences...")

     # Use the first request for all warmup iterations
-    test_prompt, test_prompt_len, test_output_len = input_requests[0]
+    test_request = input_requests[0]
+    test_prompt, test_prompt_len, test_output_len = (
+        test_request.prompt,
+        test_request.prompt_len,
+        test_request.output_len,
+    )
     if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
     else:
@@ -1194,7 +1254,11 @@ async def benchmark(
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
-        prompt, prompt_len, output_len = request
+        prompt, prompt_len, output_len = (
+            request.prompt,
+            request.prompt_len,
+            request.output_len,
+        )
         if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
@@ -1239,14 +1303,17 @@ async def benchmark(

     if "sglang" in backend:
         server_info = requests.get(base_url + "/get_server_info")
-        if pd_separated:
-            accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
-                "avg_spec_accept_length", None
-            )
+        if server_info.status_code == 200:
+            if pd_separated:
+                accept_length = server_info.json()["decode"][0]["internal_states"][
+                    0
+                ].get("avg_spec_accept_length", None)
+            else:
+                accept_length = server_info.json()["internal_states"][0].get(
+                    "avg_spec_accept_length", None
+                )
         else:
-            accept_length = server_info.json()["internal_states"][0].get(
-                "avg_spec_accept_length", None
-            )
+            accept_length = None
     else:
         accept_length = None

@@ -1380,21 +1447,24 @@ async def benchmark(
     else:
         output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"

+    result_details = {
+        "input_lens": [output.prompt_len for output in outputs],
+        "output_lens": output_lens,
+        "ttfts": [output.ttft for output in outputs],
+        "itls": [output.itl for output in outputs],
+        "generated_texts": [output.generated_text for output in outputs],
+        "errors": [output.error for output in outputs],
+    }
+
     # Append results to a JSONL file
     with open(output_file_name, "a") as file:
-        file.write(json.dumps(result) + "\n")
-
-    result.update(
-        {
-            "input_lens": [output.prompt_len for output in outputs],
-            "output_lens": output_lens,
-            "ttfts": [output.ttft for output in outputs],
-            "itls": [output.itl for output in outputs],
-            "generated_texts": [output.generated_text for output in outputs],
-            "errors": [output.error for output in outputs],
-        }
-    )
-    return result
+        if args.output_details:
+            result_for_dump = result | result_details
+        else:
+            result_for_dump = result
+        file.write(json.dumps(result_for_dump) + "\n")
+
+    return result | result_details


 def check_chat_template(model_path):
@@ -1424,6 +1494,9 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "warmup_requests"):
         args.warmup_requests = 1

+    if not hasattr(args, "output_details"):
+        args.output_details = False
+
     print(f"benchmark_args={args}")

     # Set global environments
@@ -1668,6 +1741,9 @@ if __name__ == "__main__":
         "if the server is not processing requests fast enough to keep up.",
     )
     parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+    parser.add_argument(
+        "--output-details", action="store_true", help="Output details of benchmarking."
+    )
     parser.add_argument(
         "--disable-tqdm",
         action="store_true",
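The new --output-details flag controls whether the per-request arrays collected in result_details (input lengths, output lengths, TTFTs, ITLs, generated texts, errors) are merged into the JSONL record written by the benchmark, which is typically launched as python -m sglang.bench_serving with the usual backend and model arguments. A hedged sketch of reading such a record back; the file name is a placeholder for whatever --output-file was set to:

```python
# Sketch only: reads a JSONL results file produced by bench_serving.py.
import json

with open("bench_serving_results.jsonl") as f:  # placeholder file name
    for line in f:
        record = json.loads(line)
        # The per-request arrays exist only if --output-details was passed.
        ttfts = record.get("ttfts", [])
        errors = [e for e in record.get("errors", []) if e]
        print(f"{len(ttfts)} detailed requests, {len(errors)} errors")
```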
sglang/compile_deep_gemm.py
CHANGED
@@ -82,8 +82,8 @@ def launch_server_process_and_send_one_request(
     base_url = f"http://{server_args.host}:{server_args.port}"
     timeout = compile_args.timeout

-    start_time = time.time()
-    while time.time() - start_time < timeout:
+    start_time = time.perf_counter()
+    while time.perf_counter() - start_time < timeout:
         try:
             headers = {
                 "Content-Type": "application/json; charset=utf-8",
@@ -112,9 +112,9 @@ def launch_server_process_and_send_one_request(
         raise RuntimeError(f"Sync request failed: {error}")
     # Other nodes should wait for the exit signal from Rank-0 node.
     else:
-        start_time_waiting = time.time()
+        start_time_waiting = time.perf_counter()
         while proc.is_alive():
-            if time.time() - start_time_waiting < timeout:
+            if time.perf_counter() - start_time_waiting < timeout:
                 time.sleep(10)
             else:
                 raise TimeoutError("Waiting for main node timeout!")