sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +9 -7
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +1 -0
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +48 -43
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +7 -2
  20. sglang/srt/disaggregation/fake/conn.py +1 -1
  21. sglang/srt/disaggregation/mooncake/conn.py +227 -120
  22. sglang/srt/disaggregation/nixl/conn.py +1 -0
  23. sglang/srt/disaggregation/prefill.py +7 -4
  24. sglang/srt/disaggregation/utils.py +7 -1
  25. sglang/srt/entrypoints/engine.py +17 -2
  26. sglang/srt/entrypoints/http_server.py +17 -2
  27. sglang/srt/function_call_parser.py +2 -2
  28. sglang/srt/layers/attention/flashattention_backend.py +1 -1
  29. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  30. sglang/srt/layers/attention/utils.py +4 -2
  31. sglang/srt/layers/dp_attention.py +71 -21
  32. sglang/srt/layers/layernorm.py +1 -1
  33. sglang/srt/layers/logits_processor.py +46 -11
  34. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  35. sglang/srt/layers/moe/ep_moe/layer.py +1 -1
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  37. sglang/srt/layers/moe/topk.py +1 -1
  38. sglang/srt/layers/quantization/__init__.py +1 -1
  39. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  40. sglang/srt/layers/quantization/deep_gemm.py +72 -71
  41. sglang/srt/layers/quantization/fp8.py +2 -2
  42. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  43. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  44. sglang/srt/layers/sampler.py +0 -4
  45. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  46. sglang/srt/lora/lora_manager.py +1 -1
  47. sglang/srt/lora/mem_pool.py +4 -4
  48. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  49. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  50. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  51. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  52. sglang/srt/lora/utils.py +1 -1
  53. sglang/srt/managers/data_parallel_controller.py +3 -3
  54. sglang/srt/managers/detokenizer_manager.py +21 -8
  55. sglang/srt/managers/io_struct.py +3 -1
  56. sglang/srt/managers/mm_utils.py +1 -1
  57. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  58. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  59. sglang/srt/managers/schedule_batch.py +76 -24
  60. sglang/srt/managers/schedule_policy.py +0 -3
  61. sglang/srt/managers/scheduler.py +113 -88
  62. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  63. sglang/srt/managers/tokenizer_manager.py +133 -34
  64. sglang/srt/managers/tp_worker.py +12 -9
  65. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  66. sglang/srt/mem_cache/memory_pool.py +2 -0
  67. sglang/srt/metrics/collector.py +312 -37
  68. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  69. sglang/srt/model_executor/forward_batch_info.py +1 -1
  70. sglang/srt/model_executor/model_runner.py +19 -14
  71. sglang/srt/models/deepseek_janus_pro.py +2 -2
  72. sglang/srt/models/deepseek_v2.py +23 -20
  73. sglang/srt/models/llama.py +2 -0
  74. sglang/srt/models/llama4.py +5 -6
  75. sglang/srt/models/llava.py +248 -5
  76. sglang/srt/models/mixtral.py +98 -34
  77. sglang/srt/models/pixtral.py +467 -0
  78. sglang/srt/models/roberta.py +1 -1
  79. sglang/srt/models/torch_native_llama.py +1 -1
  80. sglang/srt/openai_api/adapter.py +30 -4
  81. sglang/srt/openai_api/protocol.py +0 -8
  82. sglang/srt/reasoning_parser.py +3 -3
  83. sglang/srt/sampling/custom_logit_processor.py +18 -3
  84. sglang/srt/sampling/sampling_batch_info.py +4 -56
  85. sglang/srt/sampling/sampling_params.py +2 -2
  86. sglang/srt/server_args.py +34 -4
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +7 -7
  89. sglang/srt/speculative/eagle_worker.py +22 -19
  90. sglang/srt/utils.py +6 -5
  91. sglang/test/few_shot_gsm8k.py +2 -2
  92. sglang/test/few_shot_gsm8k_engine.py +2 -2
  93. sglang/test/run_eval.py +2 -2
  94. sglang/test/runners.py +8 -1
  95. sglang/test/send_one.py +13 -3
  96. sglang/test/simple_eval_common.py +1 -1
  97. sglang/test/simple_eval_humaneval.py +1 -1
  98. sglang/test/test_programs.py +5 -5
  99. sglang/test/test_utils.py +89 -14
  100. sglang/utils.py +1 -1
  101. sglang/version.py +1 -1
  102. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
  103. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
  104. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  105. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
  106. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py CHANGED
@@ -259,7 +259,9 @@ def throughput_test_once(
  measurement_results["total_input_tokens"]
  + measurement_results["total_output_tokens"]
  ) / latency
- measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
+ measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+ "last_gen_throughput"
+ ]

  return measurement_results

@@ -315,7 +317,7 @@ def throughput_test(
  tokenizer_id = server_args.tokenizer_path or server_args.model_path
  tokenizer = get_tokenizer(tokenizer_id)

- # Set global environmnets
+ # Set global environments
  set_ulimit()
  random.seed(bench_args.seed)
  np.random.seed(bench_args.seed)
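As this hunk shows, the benchmark now reads runtime metrics from a nested "internal_states" list in the /get_server_info response rather than from top-level keys. A minimal sketch of the new lookup, assuming a local server on the default port (the URL is illustrative, not part of this diff):

    import requests

    # /get_server_info now reports runtime metrics under "internal_states", a list;
    # the benchmark scripts read index 0.
    server_info = requests.get("http://127.0.0.1:30000/get_server_info").json()
    state = server_info["internal_states"][0]
    print(state["last_gen_throughput"])
    print(state.get("avg_spec_accept_length"))  # None unless speculative decoding is enabled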
sglang/bench_one_batch.py CHANGED
@@ -246,7 +246,7 @@ def extend(reqs, model_runner):
  _maybe_prepare_dp_attn_batch(batch, model_runner)
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
- logits_output = model_runner.forward(forward_batch)
+ logits_output, _ = model_runner.forward(forward_batch)
  next_token_ids = model_runner.sample(logits_output, forward_batch)
  return next_token_ids, logits_output.next_token_logits, batch

@@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
  _maybe_prepare_dp_attn_batch(batch, model_runner)
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
- logits_output = model_runner.forward(forward_batch)
+ logits_output, _ = model_runner.forward(forward_batch)
  next_token_ids = model_runner.sample(logits_output, forward_batch)
  return next_token_ids, logits_output.next_token_logits
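Both call sites change because model_runner.forward now returns a pair whose first element is the logits output. Scripts that need to run against both 0.4.6.post3 and post4 could wrap the call; this is a hypothetical helper, not part of the release, and it assumes the older return value is not itself a tuple:

    def forward_logits(model_runner, forward_batch):
        # post4 returns (logits_output, ...); earlier versions return logits_output directly.
        out = model_runner.forward(forward_batch)
        return out[0] if isinstance(out, tuple) else out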
sglang/bench_one_batch_server.py CHANGED
@@ -25,6 +25,7 @@ import requests
  from sglang.srt.entrypoints.http_server import launch_server
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import kill_process_tree
+ from sglang.test.test_utils import is_in_ci, write_github_step_summary


  @dataclasses.dataclass
@@ -33,9 +34,13 @@ class BenchArgs:
  batch_size: Tuple[int] = (1,)
  input_len: Tuple[int] = (1024,)
  output_len: Tuple[int] = (16,)
+ temperature: float = 0.0
+ return_logprob: bool = False
+ input_len_step_percentage: float = 0.0
  result_filename: str = "result.jsonl"
  base_url: str = ""
  skip_warmup: bool = False
+ show_report: bool = False

  @staticmethod
  def add_cli_args(parser: argparse.ArgumentParser):
@@ -49,11 +54,19 @@ class BenchArgs:
  parser.add_argument(
  "--output-len", type=int, nargs="+", default=BenchArgs.output_len
  )
+ parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+ parser.add_argument("--return-logprob", action="store_true")
+ parser.add_argument(
+ "--input-len-step-percentage",
+ type=float,
+ default=BenchArgs.input_len_step_percentage,
+ )
  parser.add_argument(
  "--result-filename", type=str, default=BenchArgs.result_filename
  )
  parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
  parser.add_argument("--skip-warmup", action="store_true")
+ parser.add_argument("--show-report", action="store_true")

  @classmethod
  def from_cli_args(cls, args: argparse.Namespace):
@@ -99,36 +112,89 @@ def run_one_case(
  batch_size: int,
  input_len: int,
  output_len: int,
+ temperature: float,
+ return_logprob: bool,
+ input_len_step_percentage: float,
  run_name: str,
  result_filename: str,
  ):
+ requests.post(url + "/flush_cache")
+ input_lens = [
+ int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
+ for i in range(batch_size)
+ ]
  input_ids = [
- [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
- for _ in range(batch_size)
+ [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
+ for i in range(batch_size)
  ]

+ use_structured_outputs = False
+ if use_structured_outputs:
+ texts = []
+ for _ in range(batch_size):
+ texts.append(
+ "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
+ * 50
+ + "Assistant:"
+ )
+ json_schema = "$$ANY$$"
+ else:
+ json_schema = None
+
  tic = time.time()
  response = requests.post(
  url + "/generate",
  json={
+ # "text": texts,
  "input_ids": input_ids,
  "sampling_params": {
- "temperature": 0,
+ "temperature": temperature,
  "max_new_tokens": output_len,
  "ignore_eos": True,
+ "json_schema": json_schema,
  },
+ "return_logprob": return_logprob,
+ "stream": True,
  },
+ stream=True,
  )
- latency = time.time() - tic

- _ = response.json()
- output_throughput = batch_size * output_len / latency
+ # The TTFT of the last request in the batch
+ ttft = 0.0
+ for chunk in response.iter_lines(decode_unicode=False):
+ chunk = chunk.decode("utf-8")
+ if chunk and chunk.startswith("data:"):
+ if chunk == "data: [DONE]":
+ break
+ data = json.loads(chunk[5:].strip("\n"))
+ if "error" in data:
+ raise RuntimeError(f"Request has failed. {data}.")
+
+ assert (
+ data["meta_info"]["finish_reason"] is None
+ or data["meta_info"]["finish_reason"]["type"] == "length"
+ )
+ if data["meta_info"]["completion_tokens"] == 1:
+ ttft = time.time() - tic
+
+ latency = time.time() - tic
+ input_throughput = batch_size * input_len / ttft
+ output_throughput = batch_size * output_len / (latency - ttft)
  overall_throughput = batch_size * (input_len + output_len) / latency

+ server_info = requests.get(url + "/get_server_info").json()
+ acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
+ last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
+
  print(f"batch size: {batch_size}")
+ print(f"input_len: {input_len}")
+ print(f"output_len: {output_len}")
  print(f"latency: {latency:.2f} s")
- print(f"output throughput: {output_throughput:.2f} token/s")
- print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+ print(f"ttft: {ttft:.2f} s")
+ print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
+ print(f"Input throughput: {input_throughput:.2f} tok/s")
+ if output_len != 1:
+ print(f"output throughput: {output_throughput:.2f} tok/s")

  if result_filename:
  with open(result_filename, "a") as fout:
@@ -140,9 +206,21 @@ def run_one_case(
  "latency": round(latency, 4),
  "output_throughput": round(output_throughput, 2),
  "overall_throughput": round(overall_throughput, 2),
+ "last_gen_throughput": round(last_gen_throughput, 2),
  }
  fout.write(json.dumps(res) + "\n")

+ return (
+ batch_size,
+ latency,
+ ttft,
+ input_throughput,
+ output_throughput,
+ overall_throughput,
+ last_gen_throughput,
+ acc_length,
+ )
+

  def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
  if bench_args.base_url:
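With the request now streamed, the script splits wall-clock time at the TTFT of the last request in the batch: input throughput is batch_size * input_len / ttft and output throughput is batch_size * output_len / (latency - ttft). A quick numeric check with made-up values:

    # Illustrative numbers only; they are not measurements from this release.
    batch_size, input_len, output_len = 16, 1024, 16
    ttft, latency = 0.5, 1.0  # seconds

    input_throughput = batch_size * input_len / ttft                      # 32768.0 tok/s
    output_throughput = batch_size * output_len / (latency - ttft)        # 512.0 tok/s
    overall_throughput = batch_size * (input_len + output_len) / latency  # 16640.0 tok/s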
@@ -152,27 +230,38 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

  # warmup
  if not bench_args.skip_warmup:
+ print("=" * 8 + " Warmup Begin " + "=" * 8)
  run_one_case(
  base_url,
  batch_size=16,
  input_len=1024,
  output_len=16,
+ temperature=bench_args.temperature,
+ return_logprob=bench_args.return_logprob,
+ input_len_step_percentage=bench_args.input_len_step_percentage,
  run_name="",
  result_filename="",
  )
+ print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

  # benchmark
+ result = []
  try:
  for bs, il, ol in itertools.product(
  bench_args.batch_size, bench_args.input_len, bench_args.output_len
  ):
- run_one_case(
- base_url,
- bs,
- il,
- ol,
- bench_args.run_name,
- bench_args.result_filename,
+ result.append(
+ run_one_case(
+ base_url,
+ bs,
+ il,
+ ol,
+ temperature=bench_args.temperature,
+ return_logprob=bench_args.return_logprob,
+ input_len_step_percentage=bench_args.input_len_step_percentage,
+ run_name=bench_args.run_name,
+ result_filename=bench_args.result_filename,
+ )
  )
  finally:
  if proc:
@@ -180,6 +269,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

  print(f"\nResults are saved to {bench_args.result_filename}")

+ if not bench_args.show_report:
+ return
+
+ summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
+ summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+
+ for (
+ batch_size,
+ latency,
+ ttft,
+ input_throughput,
+ output_throughput,
+ overall_throughput,
+ last_gen_throughput,
+ acc_length,
+ ) in result:
+ hourly_cost = 2 * server_args.tp_size # $2/hour for one H100
+ input_util = 0.7
+ accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
+ line = (
+ f"| {batch_size} | "
+ f"{latency:.2f} | "
+ f"{input_throughput:.2f} | "
+ f"{output_throughput:.2f} | "
+ f"{accept_length} | "
+ f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
+ f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
+ f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
+ )
+ summary += line
+
+ # print metrics table
+ print(summary)
+
+ if is_in_ci():
+ write_github_step_summary(
+ f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
+ )
+

  if __name__ == "__main__":
  parser = argparse.ArgumentParser()
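The new --show-report table prices tokens by converting measured throughput into GPU-hours under the assumptions hard-coded above: hourly_cost = 2 * tp_size (i.e. $2/hour per H100) and a 0.7 input-utilization factor. A worked example with made-up throughput numbers:

    # Illustrative values; only the formulas come from the script above.
    hourly_cost = 2 * 1           # tp_size = 1 -> $2/hour
    input_util = 0.7
    input_throughput = 20_000.0   # tok/s (hypothetical)
    output_throughput = 1_000.0   # tok/s (hypothetical)

    input_price = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost   # ~0.04 $ per 1M input tokens
    output_price = 1e6 / output_throughput / 3600 * hourly_cost                # ~0.56 $ per 1M output tokens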
sglang/bench_serving.py CHANGED
@@ -1103,7 +1103,7 @@ async def benchmark(
  lora_names: List[str],
  extra_request_body: Dict[str, Any],
  profile: bool,
- pd_seperated: bool = False,
+ pd_separated: bool = False,
  flush_cache: bool = False,
  warmup_requests: int = 1,
  ):
@@ -1239,12 +1239,14 @@

  if "sglang" in backend:
  server_info = requests.get(base_url + "/get_server_info")
- if pd_seperated:
- accept_length = server_info.json()["decode"][0].get(
+ if pd_separated:
+ accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
  "avg_spec_accept_length", None
  )
  else:
- accept_length = server_info.json().get("avg_spec_accept_length", None)
+ accept_length = server_info.json()["internal_states"][0].get(
+ "avg_spec_accept_length", None
+ )
  else:
  accept_length = None

@@ -1263,7 +1265,7 @@
  print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
  print(
  "{:<40} {:<10}".format(
- "Max reqeuest concurrency:",
+ "Max request concurrency:",
  max_concurrency if max_concurrency else "not set",
  )
  )
@@ -1541,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace):
  lora_names=args.lora_name,
  extra_request_body=extra_request_body,
  profile=args.profile,
- pd_seperated=args.pd_seperated,
+ pd_separated=args.pd_separated,
  flush_cache=args.flush_cache,
  )
  )
@@ -1720,7 +1722,7 @@ if __name__ == "__main__":
  help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
  )
  parser.add_argument(
- "--pd-seperated",
+ "--pd-separated",
  action="store_true",
  help="Benchmark PD disaggregation server",
  )
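Besides the spelling fix (--pd-seperated becomes --pd-separated), the accept-length lookup now follows the same internal_states nesting in both modes. A sketch of the two response shapes the script expects; the URL and flag value are placeholders:

    import requests

    base_url = "http://127.0.0.1:30000"  # placeholder
    pd_separated = False                 # set True for a PD-disaggregated deployment

    info = requests.get(base_url + "/get_server_info").json()
    if pd_separated:
        # disaggregated servers report per-role entries, each carrying its own internal_states
        accept_length = info["decode"][0]["internal_states"][0].get("avg_spec_accept_length")
    else:
        accept_length = info["internal_states"][0].get("avg_spec_accept_length")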
sglang/compile_deep_gemm.py CHANGED
@@ -129,7 +129,7 @@ def launch_server_process_and_send_one_request(


  def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
- # Disbale cuda graph and torch compile to save time
+ # Disable cuda graph and torch compile to save time
  server_args.disable_cuda_graph = True
  server_args.enable_torch_compile = False
  print(f"Disable CUDA Graph and Torch Compile to save time...")
sglang/eval/loogle_eval.py ADDED
@@ -0,0 +1,157 @@
+ import argparse
+ import asyncio
+ import os
+ import pickle
+ from pathlib import Path
+ from typing import List
+
+ import openai
+ import torch
+ from bert_score import BERTScorer
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+
+ def get_client(api_url: str) -> openai.AsyncOpenAI:
+ if os.getenv("OPENAI_API_KEY") is None:
+ os.environ["OPENAI_API_KEY"] = "EMPTY"
+ return openai.AsyncOpenAI(base_url=api_url)
+
+
+ def get_dataset():
+ return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
+
+
+ async def fetch_response(
+ client: openai.AsyncOpenAI,
+ context: str,
+ question: str,
+ semaphore: asyncio.Semaphore,
+ index: int,
+ model: str,
+ output_dir: Path,
+ ):
+ output_file = output_dir / f"response_{index}.pkl"
+ if output_file.exists():
+ return
+
+ prompt = (
+ "Please answer the question based on the long texts below.\n"
+ f"{context}\n"
+ f"Question: {question}\n"
+ "Answer:"
+ )
+ messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt},
+ ]
+
+ async with semaphore:
+ try:
+ response = await client.chat.completions.create(
+ model=model,
+ messages=messages,
+ temperature=0.0,
+ max_tokens=512,
+ )
+ except openai.BadRequestError as e:
+ with open(output_file, "wb") as f:
+ pickle.dump({"error": str(e)}, f)
+ return
+
+ with open(output_file, "wb") as f:
+ pickle.dump(response, f)
+
+
+ async def benchmark(args):
+ dataset = get_dataset()
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ client = get_client(args.api_url)
+ semaphore = asyncio.Semaphore(args.max_concurrency)
+
+ tasks: List[asyncio.Task] = []
+ for idx, ex in enumerate(dataset):
+ tasks.append(
+ asyncio.create_task(
+ fetch_response(
+ client,
+ ex["context"],
+ ex["question"],
+ semaphore,
+ idx,
+ args.model,
+ output_dir,
+ )
+ )
+ )
+
+ for _ in tqdm(
+ asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
+ ):
+ await _
+
+
+ def analyse(args):
+ dataset = get_dataset()
+ output_dir = Path(args.output_dir)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ scorer = BERTScorer(lang="en", device=device)
+
+ hyps: List[str] = []
+ refs: List[str] = []
+ for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
+ pkl_file = output_dir / f"response_{idx}.pkl"
+ if not pkl_file.exists():
+ raise FileNotFoundError(pkl_file)
+
+ response = pickle.load(open(pkl_file, "rb"))
+ if isinstance(response, dict) and "error" in response:
+ continue
+
+ hyps.append(response.choices[0].message.content.strip())
+ refs.append(ex["answer"])
+
+ if not hyps:
+ print("No valid responses to score!")
+ return
+
+ batch_size = 64
+ all_f1: List[float] = []
+ for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
+ h_batch = hyps[i : i + batch_size]
+ r_batch = refs[i : i + batch_size]
+ _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
+ all_f1.extend([float(x) for x in f1_scores])
+
+ avg = sum(all_f1) / len(all_f1)
+ print(f"Average BERTScore (F1): {avg:.2%}")
+
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Run benchmark and evaluation in one go."
+ )
+ parser.add_argument(
+ "--api-url",
+ default="http://127.0.0.1:30000/v1",
+ help="OpenAI‑compatible API base URL",
+ )
+ parser.add_argument(
+ "--model",
+ default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+ help="Model name or ID, only used for model name",
+ )
+ parser.add_argument(
+ "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
+ )
+ parser.add_argument(
+ "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
+ )
+ args = parser.parse_args()
+
+ asyncio.run(benchmark(args))
+
+ analyse(args)