isage_benchmark_agent-0.1.0.1-cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
@@ -0,0 +1,565 @@
#!/usr/bin/env python3
"""
Section 5.3.2: Scaling Analysis

Tests how method performance changes at different scales.

Analyses:
1. Tool Set Size Scaling - impact of the number of tools on accuracy
2. LLM Size Scaling - impact of model size on performance

Outputs:
- figures/fig5_analysis_scaling_tool_count.pdf
- figures/fig6_analysis_scaling_llm_size.pdf
- tables/table_scaling_results.tex

Usage:
    python exp_analysis_scaling.py
    python exp_analysis_scaling.py --tool-scaling
    python exp_analysis_scaling.py --llm-scaling
"""

from __future__ import annotations

import argparse
import time
from typing import Any, Optional

from .exp_utils import (
    get_figures_dir,
    load_benchmark_data,
    print_section_header,
    print_subsection_header,
    save_results,
    setup_experiment_env,
)

# =============================================================================
# Tool Set Size Scaling
# =============================================================================

TOOL_SCALE_POINTS = [10, 25, 50, 100, 200, 500]


def run_tool_scaling_experiment(
    max_samples: int = 50,
    strategies: Optional[list[str]] = None,
    verbose: bool = True,
) -> dict[str, list[tuple[int, float, float]]]:
    """
    Run the tool-count scaling experiment.

    Args:
        max_samples: Maximum number of test samples per scale point
        strategies: List of strategies to test
        verbose: Whether to print detailed output

    Returns:
        {strategy: [(tool_count, accuracy, latency_ms), ...]}
    """
    print_subsection_header("Tool Set Size Scaling")

    if strategies is None:
        strategies = ["selector.keyword", "selector.embedding", "selector.hybrid"]

    # Load the base data
    samples = load_benchmark_data("selection", split="test", max_samples=max_samples)
    if not samples:
        print(" ❌ No selection data available")
        return {}

    # Load or generate noise tools
    noise_tools = _generate_noise_tools(1000)

    try:
        from sage.benchmark.benchmark_agent import get_adapter_registry

        registry = get_adapter_registry()
    except ImportError:
        print(" ❌ Failed to import adapter registry")
        return {}

    results: dict[str, list[tuple[int, float, float]]] = {s: [] for s in strategies}

    for scale in TOOL_SCALE_POINTS:
        print(f"\n Scale: {scale} tools")

        for strategy_name in strategies:
            try:
                selector = registry.get(strategy_name)
            except Exception as e:
                print(f" ⚠️ {strategy_name}: {e}")
                continue

            # Run the test
            hits = 0
            total_time = 0.0

            for sample in samples:
                query = sample.get("instruction", "")
                base_tools = sample.get("candidate_tools", [])
                ground_truth = sample.get("ground_truth", [])

                # Extend the tool set to the target scale
                extended_tools = _extend_tool_set(base_tools, noise_tools, scale)

                start = time.time()
                try:
                    predictions = selector.select(query, candidate_tools=extended_tools, top_k=5)
                    pred_ids = (
                        [p.tool_id if hasattr(p, "tool_id") else str(p) for p in predictions]
                        if predictions
                        else []
                    )

                    # Count a hit if any top-5 prediction matches the reference set
                    ref_set = (
                        set(ground_truth) if isinstance(ground_truth, list) else {ground_truth}
                    )
                    if set(pred_ids[:5]) & ref_set:
                        hits += 1
                except Exception:
                    pass

                total_time += time.time() - start

            accuracy = hits / len(samples) if samples else 0
            avg_latency = total_time * 1000 / len(samples) if samples else 0

            results[strategy_name].append((scale, accuracy, avg_latency))

            if verbose:
                print(
                    f" {strategy_name.split('.')[-1]:12s}: {accuracy * 100:5.1f}% ({avg_latency:.1f}ms)"
                )

    return results
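
# Illustrative shape of run_tool_scaling_experiment()'s return value (the
# numbers below are made-up placeholders, not measured results):
#   {
#       "selector.keyword":   [(10, 0.80, 2.5), (25, 0.72, 4.1), ..., (500, 0.41, 35.0)],
#       "selector.embedding": [(10, 0.86, 8.3), ...],
#       "selector.hybrid":    [(10, 0.88, 9.0), ...],
#   }
# Each tuple is (tool_count, accuracy as a 0-1 fraction, mean latency in ms).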


def _generate_noise_tools(count: int) -> list[str]:
    """Generate distractor tools."""
    import random

    categories = ["search", "calendar", "email", "file", "math", "weather", "news", "social"]
    actions = ["get", "set", "create", "delete", "update", "list", "find", "check"]

    tools = []
    for i in range(count):
        cat = random.choice(categories)
        action = random.choice(actions)
        tools.append(f"noise_{cat}_{action}_{i:04d}")

    return tools
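
# For example, _generate_noise_tools(2) might yield names such as
# ["noise_email_get_0000", "noise_math_list_0001"]; the category/action parts
# are random, while the numeric suffix keeps every name unique.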


def _extend_tool_set(base_tools: list[str], noise_tools: list[str], target_size: int) -> list[str]:
    """Pad the tool set with noise tools up to the target size."""
    import random

    result = list(base_tools)
    needed = target_size - len(result)

    if needed > 0:
        available_noise = [t for t in noise_tools if t not in result]
        result.extend(random.sample(available_noise, min(needed, len(available_noise))))

    return result[:target_size]
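
# Note: if base_tools already exceeds target_size, the final slice truncates it,
# which can drop ground-truth tools at the small scale points (10, 25).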


# =============================================================================
# LLM Size Scaling
# =============================================================================

# Model list (ordered by parameter count)
# 2x A100 80GB can run models up to 14B on a single GPU
LLM_MODELS = [
    # Small models
    ("Qwen/Qwen2.5-0.5B-Instruct", "0.5B", 1),  # ~1GB VRAM
    ("Qwen/Qwen2.5-1.5B-Instruct", "1.5B", 1),  # ~3GB VRAM
    ("Qwen/Qwen2.5-3B-Instruct", "3B", 1),  # ~6GB VRAM
    # Medium models
    ("Qwen/Qwen2.5-7B-Instruct", "7B", 1),  # ~14GB VRAM
    ("Qwen/Qwen2.5-14B-Instruct", "14B", 1),  # ~28GB VRAM
]

# Fallback: if Qwen downloads are slow, Llama can be used instead
LLM_MODELS_LLAMA = [
    ("meta-llama/Llama-3.2-1B-Instruct", "1B", 1),
    ("meta-llama/Llama-3.2-3B-Instruct", "3B", 1),
    ("meta-llama/Llama-3.1-8B-Instruct", "8B", 1),
]
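
# Each entry is (HuggingFace model ID, size label reported in results/figures,
# tensor-parallel degree passed to the vLLM server).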


def _start_vllm_server(model_id: str, tensor_parallel: int = 1, port: int = 8901) -> bool:
    """
    Start a vLLM server.

    Args:
        model_id: HuggingFace model ID
        tensor_parallel: Tensor-parallel size (number of GPUs)
        port: Server port

    Returns:
        Whether the server started successfully
    """
    import subprocess
    import time

    import requests  # type: ignore[import-untyped]

    # Stop any existing server first
    _stop_vllm_server(port)
    time.sleep(2)

    print(f" Starting vLLM server for {model_id}...")

    cmd = [
        "vllm",
        "serve",
        model_id,
        "--port",
        str(port),
        "--gpu-memory-utilization",
        "0.85",
        "--max-model-len",
        "4096",
        "--trust-remote-code",
    ]

    if tensor_parallel > 1:
        cmd.extend(["--tensor-parallel-size", str(tensor_parallel)])

    # Launch in the background
    try:
        subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            start_new_session=True,
        )
    except Exception as e:
        print(f" Failed to start vLLM: {e}")
        return False

    # Wait for the server to become ready (up to 5 minutes)
    max_wait = 300
    start_time = time.time()

    while time.time() - start_time < max_wait:
        try:
            resp = requests.get(f"http://localhost:{port}/v1/models", timeout=5)
            if resp.status_code == 200:
                print(f" vLLM server ready (took {time.time() - start_time:.0f}s)")
                return True
        except Exception:
            pass
        time.sleep(5)

    print(" Timeout waiting for vLLM server")
    return False
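
# For a single-GPU model the constructed command is equivalent to running, e.g.:
#   vllm serve Qwen/Qwen2.5-7B-Instruct --port 8901 \
#       --gpu-memory-utilization 0.85 --max-model-len 4096 --trust-remote-code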


def _stop_vllm_server(port: int = 8901) -> None:
    """Stop the vLLM server."""
    import subprocess

    # Find the process by port and kill it
    try:
        result = subprocess.run(
            ["lsof", "-t", f"-i:{port}"],
            capture_output=True,
            text=True,
        )
        if result.stdout.strip():
            pids = result.stdout.strip().split("\n")
            for pid in pids:
                subprocess.run(["kill", "-9", pid], capture_output=True)
    except Exception:
        pass
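
# This relies on the POSIX `lsof` and `kill` utilities being on PATH; on systems
# without them, the broad except silently leaves any running server in place.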


def _test_llm_on_challenge(
    model_id: str,
    samples: list[dict],
    challenge: str,
    port: int = 8901,
) -> tuple[float, float]:
    """
    Test LLM performance on the specified challenge.

    Returns:
        (accuracy, avg_latency_ms)
    """
    import time

    import requests

    base_url = f"http://localhost:{port}/v1"

    correct = 0
    total_latency = 0.0
    tested = 0

    for sample in samples:
        query = sample.get("instruction", sample.get("query", ""))
        ground_truth = sample.get("ground_truth", sample.get("label", ""))

        # Build the prompt (prompts are in Chinese; rough translations in comments)
        if challenge == "timing":
            # "Decide whether the following user request needs an external tool.
            #  Answer only 'yes' or 'no'. / User request: ... / Answer:"
            prompt = f"""判断以下用户请求是否需要调用外部工具。只回答 "yes" 或 "no"。

用户请求: {query}

回答:"""
        elif challenge == "planning":
            tools = sample.get("candidate_tools", [])
            tool_str = ", ".join(tools[:10]) if tools else "search, calculate, weather"
            # "Generate a tool-call plan for the user request. Available tools: ...
            #  / User request: ... / Plan (JSON format):"
            prompt = f"""根据用户请求,生成工具调用计划。可用工具: {tool_str}

用户请求: {query}

计划 (JSON 格式):"""
        else:  # selection
            tools = sample.get("candidate_tools", [])
            tool_str = "\n".join([f"- {t}" for t in tools[:20]])
            # "Choose the most suitable tool for the user request from the list below.
            #  / Available tools: ... / User request: ... / Chosen tool (name only):"
            prompt = f"""从以下工具中选择最适合的工具来完成用户请求。

可用工具:
{tool_str}

用户请求: {query}

选择的工具 (只输出工具名):"""

        try:
            start = time.time()
            resp = requests.post(
                f"{base_url}/completions",
                json={
                    "model": model_id,
                    "prompt": prompt,
                    "max_tokens": 256,
                    "temperature": 0.1,
                },
                timeout=60,
            )
            latency = (time.time() - start) * 1000

            if resp.status_code == 200:
                result = resp.json()
                output = result.get("choices", [{}])[0].get("text", "").strip().lower()

                # Simple evaluation
                if challenge == "timing":
                    # "需要" means "needed" (a Chinese affirmative answer)
                    pred = "yes" in output or "需要" in output
                    label = (
                        ground_truth
                        if isinstance(ground_truth, bool)
                        else str(ground_truth).lower() in ["yes", "true", "1"]
                    )
                    if pred == label:
                        correct += 1
                elif challenge == "planning":
                    # Check whether the output contains a correct tool
                    if isinstance(ground_truth, list):
                        if any(t.lower() in output for t in ground_truth):
                            correct += 1
                    elif str(ground_truth).lower() in output:
                        correct += 1
                else:  # selection
                    if isinstance(ground_truth, list):
                        if any(t.lower() in output for t in ground_truth):
                            correct += 1
                    elif str(ground_truth).lower() in output:
                        correct += 1

                total_latency += latency
                tested += 1

        except Exception as e:
            print(f" Error: {e}")
            continue

    accuracy = correct / tested if tested > 0 else 0.0
    avg_latency = total_latency / tested if tested > 0 else 0.0

    return accuracy, avg_latency
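
# Scoring is deliberately lenient: planning and selection count an answer as
# correct when any ground-truth tool name appears as a substring of the
# (lower-cased) completion, so verbose outputs are not penalized.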


def run_llm_scaling_experiment(
    max_samples: int = 30,
    challenge: str = "planning",
    models: Optional[list[Any]] = None,
    verbose: bool = True,
) -> dict[str, list[tuple[str, float, float]]]:
    """
    Run the LLM size scaling experiment (live test).

    Args:
        max_samples: Maximum number of test samples
        challenge: Challenge to test (planning best exposes the differences)
        models: Models to test (None uses the default list)
        verbose: Whether to print detailed output

    Returns:
        {metric: [(model_size, value, latency_ms), ...]}
    """
    print_subsection_header("LLM Size Scaling (Real Test)")

    if models is None:
        models = LLM_MODELS

    samples = load_benchmark_data(challenge, split="test", max_samples=max_samples)
    if not samples:
        print(f" No {challenge} data available")
        return {}

    print(f" Testing {len(models)} models on {len(samples)} samples")
    print(f" Challenge: {challenge}")

    results: dict[str, list[tuple[str, float, float]]] = {"accuracy": [], "latency": []}
    port = 8901

    for model_id, model_size, tensor_parallel in models:
        print(f"\n [{model_size}] {model_id}")

        # Start the vLLM server
        if not _start_vllm_server(model_id, tensor_parallel, port):
            print(f" Skipping {model_size} (failed to start)")
            continue

        try:
            # Run the test
            accuracy, avg_latency = _test_llm_on_challenge(model_id, samples, challenge, port)

            results["accuracy"].append((model_size, accuracy, avg_latency))

            if verbose:
                print(f" Accuracy: {accuracy * 100:.1f}%")
                print(f" Avg Latency: {avg_latency:.0f}ms")

        finally:
            # Stop the server to free GPU memory
            _stop_vllm_server(port)
            import time

            time.sleep(5)  # Wait for GPU memory to be released

    return results
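
# Illustrative shape of the returned dict (placeholder numbers, not measurements):
#   {"accuracy": [("0.5B", 0.30, 450.0), ("1.5B", 0.42, 600.0), ...], "latency": []}
# Note the "latency" list is declared but never filled; latency rides along as
# the third element of each "accuracy" tuple.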


# =============================================================================
# Main Experiment
# =============================================================================


def run_scaling_analysis(
    tool_scaling: bool = True,
    llm_scaling: bool = True,
    max_samples: int = 50,
    verbose: bool = True,
) -> dict[str, Any]:
    """
    Run the full scaling analysis.
    """
    setup_experiment_env(verbose=verbose)

    print_section_header("Section 5.3.2: Scaling Analysis")

    all_results = {}

    if tool_scaling:
        tool_results = run_tool_scaling_experiment(max_samples=max_samples, verbose=verbose)
        all_results["tool_scaling"] = tool_results

        # Analyze the results
        if tool_results:
            print("\n Tool Scaling Summary:")
            for strategy, data in tool_results.items():
                if data:
                    # Compute the relative accuracy drop from smallest to largest scale
                    first_acc = data[0][1]
                    last_acc = data[-1][1]
                    drop_rate = (first_acc - last_acc) / first_acc if first_acc > 0 else 0
                    print(
                        f" {strategy}: {first_acc * 100:.1f}% → {last_acc * 100:.1f}% (drop: {drop_rate * 100:.1f}%)"
                    )

    if llm_scaling:
        llm_results = run_llm_scaling_experiment(max_samples=max_samples, verbose=verbose)
        all_results["llm_scaling"] = llm_results  # type: ignore[assignment]

    # Save the results
    output_file = save_results(all_results, "5_3_analysis", "scaling_analysis")
    print(f"\n Results saved to: {output_file}")

    # Generate figures
    _generate_scaling_figures(all_results)

    return all_results
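
# Example of the drop-rate computation above: if accuracy falls from 80% at
# 10 tools to 40% at 500 tools, drop_rate = (0.80 - 0.40) / 0.80 = 0.50, i.e.
# a 50% relative degradation.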


def _generate_scaling_figures(results: dict) -> None:
    """Generate the scaling analysis figures."""
    try:
        from .figure_generator import plot_scaling_curve

        figures_dir = get_figures_dir()

        # Tool scaling curve
        if "tool_scaling" in results and results["tool_scaling"]:
            tool_data = {}
            for strategy, data in results["tool_scaling"].items():
                tool_data[strategy.split(".")[-1]] = [(d[0], d[1]) for d in data]

            plot_scaling_curve(
                tool_data,
                x_label="Number of Candidate Tools",
                y_label="Top-5 Accuracy (%)",
                title="Tool Set Size Scaling",
                output_path=figures_dir / "fig5_analysis_scaling_tool_count.pdf",
                log_x=True,
            )
            print(" Figure saved: fig5_analysis_scaling_tool_count.pdf")

        # LLM scaling curve
        if "llm_scaling" in results and results["llm_scaling"].get("accuracy"):
            llm_data = {"planning": [(d[0], d[1]) for d in results["llm_scaling"]["accuracy"]]}
            plot_scaling_curve(
                llm_data,
                x_label="Model Size",
                y_label="Plan Success Rate (%)",
                title="LLM Size Scaling",
                output_path=figures_dir / "fig6_analysis_scaling_llm_size.pdf",
            )
            print(" Figure saved: fig6_analysis_scaling_llm_size.pdf")

    except Exception as e:
        print(f" Warning: Could not generate figures: {e}")


def main():
    parser = argparse.ArgumentParser(description="Section 5.3.2: Scaling Analysis")
    parser.add_argument("--tool-scaling", action="store_true", help="Run tool scaling only")
    parser.add_argument("--llm-scaling", action="store_true", help="Run LLM scaling only")
    parser.add_argument("--max-samples", type=int, default=50, help="Maximum samples per test")
    parser.add_argument("--verbose", action="store_true", default=True, help="Verbose output")
    args = parser.parse_args()

    # If no specific type was given, run both analyses
    run_tool = args.tool_scaling or not (args.tool_scaling or args.llm_scaling)
    run_llm = args.llm_scaling or not (args.tool_scaling or args.llm_scaling)

    run_scaling_analysis(
        tool_scaling=run_tool,
        llm_scaling=run_llm,
        max_samples=args.max_samples,
        verbose=args.verbose,
    )

    print("\n" + "=" * 70)
    print("📊 Scaling Analysis Complete")
    print("=" * 70)


if __name__ == "__main__":
    main()
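
For reference, the analysis can also be driven from Python instead of the CLI. A minimal sketch, assuming the wheel is installed and the module is importable under the package path shown in the file listing above (the keyword arguments mirror the CLI flags):

    from sage.benchmark.benchmark_agent.scripts.experiments.exp_analysis_scaling import (
        run_scaling_analysis,
    )

    # Run only the tool-count scaling on a small sample budget
    results = run_scaling_analysis(tool_scaling=True, llm_scaling=False, max_samples=20)
    print(results.get("tool_scaling", {}))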