sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py CHANGED

@@ -259,7 +259,9 @@ def throughput_test_once(
             measurement_results["total_input_tokens"]
             + measurement_results["total_output_tokens"]
         ) / latency
-        measurement_results["last_gen_throughput"] = server_info["
+        measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+            "last_gen_throughput"
+        ]
 
     return measurement_results
 
@@ -315,7 +317,7 @@ def throughput_test(
     tokenizer_id = server_args.tokenizer_path or server_args.model_path
     tokenizer = get_tokenizer(tokenizer_id)
 
-    # Set global
+    # Set global environments
     set_ulimit()
     random.seed(bench_args.seed)
     np.random.seed(bench_args.seed)
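The first hunk tracks a change in the `/get_server_info` response shape: `last_gen_throughput` now lives inside an `internal_states` list rather than at the top level. A minimal sketch of reading it the way the benchmark now does, assuming an sglang server is already running locally (the URL is illustrative):

```python
import requests

# Illustrative base URL; point this at the server being benchmarked.
base_url = "http://127.0.0.1:30000"

server_info = requests.get(base_url + "/get_server_info").json()
# Per the diff, per-scheduler runtime stats are now nested under "internal_states".
state = server_info["internal_states"][0]
print("last generation throughput (tok/s):", state["last_gen_throughput"])
```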
sglang/bench_one_batch.py CHANGED

@@ -246,7 +246,7 @@ def extend(reqs, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch
 
@@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
 
sglang/bench_one_batch_server.py CHANGED

@@ -25,6 +25,7 @@ import requests
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import is_in_ci, write_github_step_summary
 
 
 @dataclasses.dataclass
@@ -33,9 +34,13 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    temperature: float = 0.0
+    return_logprob: bool = False
+    input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
     skip_warmup: bool = False
+    show_report: bool = False
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -49,11 +54,19 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--input-len-step-percentage",
+            type=float,
+            default=BenchArgs.input_len_step_percentage,
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
         parser.add_argument("--skip-warmup", action="store_true")
+        parser.add_argument("--show-report", action="store_true")
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -99,36 +112,89 @@ def run_one_case(
     batch_size: int,
     input_len: int,
     output_len: int,
+    temperature: float,
+    return_logprob: bool,
+    input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
 ):
+    requests.post(url + "/flush_cache")
+    input_lens = [
+        int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
+        for i in range(batch_size)
+    ]
     input_ids = [
-        [int(x) for x in np.random.randint(0, high=16384, size=(
-        for
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
+        for i in range(batch_size)
     ]
 
+    use_structured_outputs = False
+    if use_structured_outputs:
+        texts = []
+        for _ in range(batch_size):
+            texts.append(
+                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
+                * 50
+                + "Assistant:"
+            )
+        json_schema = "$$ANY$$"
+    else:
+        json_schema = None
+
     tic = time.time()
     response = requests.post(
         url + "/generate",
         json={
+            # "text": texts,
             "input_ids": input_ids,
             "sampling_params": {
-                "temperature":
+                "temperature": temperature,
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
+                "json_schema": json_schema,
            },
+            "return_logprob": return_logprob,
+            "stream": True,
        },
+        stream=True,
    )
-    latency = time.time() - tic
 
-
-
+    # The TTFT of the last request in the batch
+    ttft = 0.0
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
+            if "error" in data:
+                raise RuntimeError(f"Request has failed. {data}.")
+
+            assert (
+                data["meta_info"]["finish_reason"] is None
+                or data["meta_info"]["finish_reason"]["type"] == "length"
+            )
+            if data["meta_info"]["completion_tokens"] == 1:
+                ttft = time.time() - tic
+
+    latency = time.time() - tic
+    input_throughput = batch_size * input_len / ttft
+    output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency
 
+    server_info = requests.get(url + "/get_server_info").json()
+    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
+    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
+
     print(f"batch size: {batch_size}")
+    print(f"input_len: {input_len}")
+    print(f"output_len: {output_len}")
     print(f"latency: {latency:.2f} s")
-    print(f"
-    print(f"
+    print(f"ttft: {ttft:.2f} s")
+    print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
+    print(f"Input throughput: {input_throughput:.2f} tok/s")
+    if output_len != 1:
+        print(f"output throughput: {output_throughput:.2f} tok/s")
 
     if result_filename:
         with open(result_filename, "a") as fout:
@@ -140,9 +206,21 @@ def run_one_case(
                 "latency": round(latency, 4),
                 "output_throughput": round(output_throughput, 2),
                 "overall_throughput": round(overall_throughput, 2),
+                "last_gen_throughput": round(last_gen_throughput, 2),
             }
             fout.write(json.dumps(res) + "\n")
 
+    return (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    )
+
 
 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     if bench_args.base_url:
@@ -152,27 +230,38 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
 
     # warmup
     if not bench_args.skip_warmup:
+        print("=" * 8 + " Warmup Begin " + "=" * 8)
         run_one_case(
             base_url,
             batch_size=16,
             input_len=1024,
             output_len=16,
+            temperature=bench_args.temperature,
+            return_logprob=bench_args.return_logprob,
+            input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
         )
+        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
 
     # benchmark
+    result = []
     try:
         for bs, il, ol in itertools.product(
             bench_args.batch_size, bench_args.input_len, bench_args.output_len
         ):
-
-
-
-
-
-
-
+            result.append(
+                run_one_case(
+                    base_url,
+                    bs,
+                    il,
+                    ol,
+                    temperature=bench_args.temperature,
+                    return_logprob=bench_args.return_logprob,
+                    input_len_step_percentage=bench_args.input_len_step_percentage,
+                    run_name=bench_args.run_name,
+                    result_filename=bench_args.result_filename,
+                )
             )
     finally:
         if proc:
@@ -180,6 +269,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
 
     print(f"\nResults are saved to {bench_args.result_filename}")
 
+    if not bench_args.show_report:
+        return
+
+    summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
+    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+
+    for (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    ) in result:
+        hourly_cost = 2 * server_args.tp_size  # $2/hour for one H100
+        input_util = 0.7
+        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
+        line = (
+            f"| {batch_size} | "
+            f"{latency:.2f} | "
+            f"{input_throughput:.2f} | "
+            f"{output_throughput:.2f} | "
+            f"{accept_length} | "
+            f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
+            f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
+            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
+        )
+        summary += line
+
+    # print metrics table
+    print(summary)
+
+    if is_in_ci():
+        write_github_step_summary(
+            f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
+        )
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
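The new `--show-report` path converts each per-case result tuple into a markdown row with an inter-token latency (ITL) and rough $-per-million-token prices. A standalone sketch of that arithmetic with made-up measurements (the $2/hour-per-H100 rate and the 0.7 input-utilization factor are the constants the report code assumes):

```python
# Made-up measurements, for illustration only.
batch_size = 16
input_throughput = 30000.0   # prefill tok/s
output_throughput = 2500.0   # decode tok/s
tp_size = 1

hourly_cost = 2 * tp_size    # $2/hour per H100, as assumed in the report code
input_util = 0.7             # assumed prefill utilization, as in the report code

itl_ms = 1 / (output_throughput / batch_size) * 1000
input_price = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost
output_price = 1e6 / output_throughput / 3600 * hourly_cost

print(f"ITL: {itl_ms:.2f} ms")
print(f"input price: ${input_price:.2f} / 1M tokens")
print(f"output price: ${output_price:.2f} / 1M tokens")
```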
sglang/bench_serving.py CHANGED

@@ -1103,7 +1103,7 @@ async def benchmark(
     lora_names: List[str],
     extra_request_body: Dict[str, Any],
     profile: bool,
-
+    pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
 ):
@@ -1239,12 +1239,14 @@
 
     if "sglang" in backend:
         server_info = requests.get(base_url + "/get_server_info")
-        if
-            accept_length = server_info.json()["decode"][0].get(
+        if pd_separated:
+            accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
                 "avg_spec_accept_length", None
             )
         else:
-            accept_length = server_info.json().get(
+            accept_length = server_info.json()["internal_states"][0].get(
+                "avg_spec_accept_length", None
+            )
     else:
         accept_length = None
 
@@ -1263,7 +1265,7 @@
     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
     print(
         "{:<40} {:<10}".format(
-            "Max
+            "Max request concurrency:",
             max_concurrency if max_concurrency else "not set",
         )
     )
@@ -1541,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace):
             lora_names=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
-
+            pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
         )
     )
@@ -1720,7 +1722,7 @@ if __name__ == "__main__":
         help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
     )
     parser.add_argument(
-        "--pd-
+        "--pd-separated",
         action="store_true",
         help="Benchmark PD disaggregation server",
     )
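The `--pd-separated` changes make the speculative-decoding accept-length lookup depend on the deployment: a PD-disaggregated server reports it under the decode worker's entry, a regular server under its own `internal_states`. A small sketch of that branch, operating on an already-fetched `/get_server_info` payload (the helper name is hypothetical, not part of the library):

```python
def get_accept_length(info: dict, pd_separated: bool = False):
    # Sketch mirroring the lookup added in bench_serving.py; not a library API.
    if pd_separated:
        return info["decode"][0]["internal_states"][0].get("avg_spec_accept_length", None)
    return info["internal_states"][0].get("avg_spec_accept_length", None)
```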
sglang/compile_deep_gemm.py CHANGED

@@ -129,7 +129,7 @@ def launch_server_process_and_send_one_request(
 
 
 def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
-    #
+    # Disable cuda graph and torch compile to save time
     server_args.disable_cuda_graph = True
     server_args.enable_torch_compile = False
     print(f"Disable CUDA Graph and Torch Compile to save time...")
sglang/eval/loogle_eval.py ADDED

@@ -0,0 +1,157 @@
+import argparse
+import asyncio
+import os
+import pickle
+from pathlib import Path
+from typing import List
+
+import openai
+import torch
+from bert_score import BERTScorer
+from datasets import load_dataset
+from tqdm import tqdm
+
+
+def get_client(api_url: str) -> openai.AsyncOpenAI:
+    if os.getenv("OPENAI_API_KEY") is None:
+        os.environ["OPENAI_API_KEY"] = "EMPTY"
+    return openai.AsyncOpenAI(base_url=api_url)
+
+
+def get_dataset():
+    return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
+
+
+async def fetch_response(
+    client: openai.AsyncOpenAI,
+    context: str,
+    question: str,
+    semaphore: asyncio.Semaphore,
+    index: int,
+    model: str,
+    output_dir: Path,
+):
+    output_file = output_dir / f"response_{index}.pkl"
+    if output_file.exists():
+        return
+
+    prompt = (
+        "Please answer the question based on the long texts below.\n"
+        f"{context}\n"
+        f"Question: {question}\n"
+        "Answer:"
+    )
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": prompt},
+    ]
+
+    async with semaphore:
+        try:
+            response = await client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=0.0,
+                max_tokens=512,
+            )
+        except openai.BadRequestError as e:
+            with open(output_file, "wb") as f:
+                pickle.dump({"error": str(e)}, f)
+            return
+
+    with open(output_file, "wb") as f:
+        pickle.dump(response, f)
+
+
+async def benchmark(args):
+    dataset = get_dataset()
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    client = get_client(args.api_url)
+    semaphore = asyncio.Semaphore(args.max_concurrency)
+
+    tasks: List[asyncio.Task] = []
+    for idx, ex in enumerate(dataset):
+        tasks.append(
+            asyncio.create_task(
+                fetch_response(
+                    client,
+                    ex["context"],
+                    ex["question"],
+                    semaphore,
+                    idx,
+                    args.model,
+                    output_dir,
+                )
+            )
+        )
+
+    for _ in tqdm(
+        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
+    ):
+        await _
+
+
+def analyse(args):
+    dataset = get_dataset()
+    output_dir = Path(args.output_dir)
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    scorer = BERTScorer(lang="en", device=device)
+
+    hyps: List[str] = []
+    refs: List[str] = []
+    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
+        pkl_file = output_dir / f"response_{idx}.pkl"
+        if not pkl_file.exists():
+            raise FileNotFoundError(pkl_file)
+
+        response = pickle.load(open(pkl_file, "rb"))
+        if isinstance(response, dict) and "error" in response:
+            continue
+
+        hyps.append(response.choices[0].message.content.strip())
+        refs.append(ex["answer"])
+
+    if not hyps:
+        print("No valid responses to score!")
+        return
+
+    batch_size = 64
+    all_f1: List[float] = []
+    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
+        h_batch = hyps[i : i + batch_size]
+        r_batch = refs[i : i + batch_size]
+        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
+        all_f1.extend([float(x) for x in f1_scores])
+
+    avg = sum(all_f1) / len(all_f1)
+    print(f"Average BERTScore (F1): {avg:.2%}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run benchmark and evaluation in one go."
+    )
+    parser.add_argument(
+        "--api-url",
+        default="http://127.0.0.1:30000/v1",
+        help="OpenAI‑compatible API base URL",
+    )
+    parser.add_argument(
+        "--model",
+        default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+        help="Model name or ID, only used for model name",
+    )
+    parser.add_argument(
+        "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
+    )
+    parser.add_argument(
+        "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
+    )
+    args = parser.parse_args()
+
+    asyncio.run(benchmark(args))
+
+    analyse(args)
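The new eval script caches one pickle per dataset example, so it can be re-run to resume or to re-score without re-querying the server. A small sketch for inspecting a single cached response, following the script's defaults (directory name and index are illustrative):

```python
import pickle
from pathlib import Path

# Default --output-dir is "tmp-output-dir"; response_0.pkl is the first example.
pkl_file = Path("tmp-output-dir") / "response_0.pkl"
with open(pkl_file, "rb") as f:
    response = pickle.load(f)

if isinstance(response, dict) and "error" in response:
    # Failed requests are stored as {"error": ...} dicts by fetch_response.
    print("request failed:", response["error"])
else:
    print(response.choices[0].message.content.strip())
```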