isage-benchmark-agent 0.1.0.1__cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Section 5.3.4: Ablation Studies
|
|
4
|
+
|
|
5
|
+
分析各方法关键组件的贡献。
|
|
6
|
+
|
|
7
|
+
分析内容:
|
|
8
|
+
1. Prompt Design Ablation - Prompt 设计消融
|
|
9
|
+
2. Hybrid Method Component Ablation - 混合方法组件消融
|
|
10
|
+
|
|
11
|
+
输出:
|
|
12
|
+
- figures/fig8_analysis_ablation.pdf
|
|
13
|
+
- tables/table_ablation_results.tex
|
|
14
|
+
|
|
15
|
+
Usage:
|
|
16
|
+
python exp_analysis_ablation.py
|
|
17
|
+
python exp_analysis_ablation.py --prompt-ablation
|
|
18
|
+
python exp_analysis_ablation.py --hybrid-ablation
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import argparse
|
|
24
|
+
from typing import Any
|
|
25
|
+
|
|
26
|
+
from .exp_utils import (
|
|
27
|
+
get_figures_dir,
|
|
28
|
+
load_benchmark_data,
|
|
29
|
+
print_section_header,
|
|
30
|
+
print_subsection_header,
|
|
31
|
+
save_results,
|
|
32
|
+
setup_experiment_env,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# =============================================================================
# Prompt Design Ablation
# =============================================================================

# Prompt variants compared by the prompt-design ablation.
# Each variant maps to a system prompt and a user-message template whose
# placeholders ({query}, {tools} and, where present, {k}) are str.format fields.
# Note that "minimal" and "with_examples" do not request a specific top-k.
PROMPT_VARIANTS = dict(
    # Bare-bones instruction: no role framing, no top-k request.
    minimal=dict(
        system="Select tools.",
        template="Query: {query}\nTools: {tools}\nAnswer:",
    ),
    # Role framing plus an explicit top-k request.
    standard=dict(
        system="You are a tool selection assistant. Select the most relevant tools for the user query.",
        template="Query: {query}\n\nAvailable Tools:\n{tools}\n\nSelect the top-{k} most relevant tools.",
    ),
    # Two-shot prompt: worked examples precede the real query.
    with_examples=dict(
        system="You are a tool selection assistant.",
        template="""Here are some examples:

Example 1:
Query: What's the weather today?
Tools: weather_api, calendar, news
Answer: weather_api

Example 2:
Query: Schedule a meeting
Tools: calendar, email, contacts
Answer: calendar

Now your turn:
Query: {query}
Tools: {tools}
Answer:""",
    ),
    # Chain-of-thought scaffold: asks for reasoning before the final top-k answer.
    with_cot=dict(
        system="You are a tool selection assistant. Think step by step.",
        template="""Query: {query}

Available Tools:
{tools}

Let's analyze step by step:
1. What is the user trying to accomplish?
2. What capabilities are needed?
3. Which tools provide these capabilities?

Reasoning: <your analysis>

Final Selection (top-{k}):""",
    ),
)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def run_prompt_ablation(
    max_samples: int = 30,
    verbose: bool = True,
) -> dict[str, dict[str, float]]:
    """Run the prompt-design ablation (only meaningful for LLM-based selectors).

    WARNING: no LLM is called yet. The accuracy and latency reported for each
    variant are hard-coded placeholder estimates, not measurements. The selection
    data is loaded only to check that it is available. Every returned entry carries
    ``is_estimate=True`` so downstream tables and figures can tell these values
    apart from measured results.

    Args:
        max_samples: Forwarded to ``load_benchmark_data`` for the availability check.
        verbose: Print per-variant accuracy and latency.

    Returns:
        ``{prompt_variant: {metric: value}}`` with the keys ``top_k_accuracy``,
        ``avg_latency_ms``, ``prompt_length`` (characters in system prompt plus
        template) and ``is_estimate``. Returns an empty dict when no selection
        data is available.
    """
    print_subsection_header("Prompt Design Ablation")

    samples = load_benchmark_data("selection", split="test", max_samples=max_samples)
    if not samples:
        print(" ❌ No selection data available")
        return {}

    # TODO: integrate a real prompt-injection mechanism into the LLM selector and
    # measure these values instead of using the placeholders below.
    # Assumed trend: minimal < standard < with_examples ≈ with_cot.
    # Both tables are loop-invariant, so they are built once here.
    estimated_accuracy = {
        "minimal": 0.65,
        "standard": 0.78,
        "with_examples": 0.85,
        "with_cot": 0.83,
    }
    estimated_latency_ms = {"minimal": 50, "standard": 80, "with_examples": 150, "with_cot": 200}

    # Make the placeholder nature impossible to miss in the console output.
    print(" ⚠️ Prompt ablation uses placeholder estimates (no LLM calls); values are NOT measured")

    results = {}

    for variant_name, variant_config in PROMPT_VARIANTS.items():
        print(f"\n Testing prompt variant: {variant_name}")

        # Variants missing from the tables fall back to generic placeholders.
        accuracy = estimated_accuracy.get(variant_name, 0.70)
        latency = estimated_latency_ms.get(variant_name, 100)

        results[variant_name] = {
            "top_k_accuracy": accuracy,
            "avg_latency_ms": latency,
            "prompt_length": len(variant_config["system"]) + len(variant_config["template"]),
            # Flags the numbers above as placeholders rather than measurements.
            "is_estimate": True,
        }

        if verbose:
            print(f" Accuracy: {accuracy * 100:.1f}%")
            print(f" Latency: {latency}ms")

    return results
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# =============================================================================
# Hybrid Method Component Ablation
# =============================================================================

# Keyword/embedding weight grid for the hybrid selector ablation.
# Each row is (config name, keyword weight, embedding weight); the weights sum to 1.
HYBRID_SELECTION_CONFIGS = [
    {"name": label, "keyword_weight": kw_weight, "embedding_weight": emb_weight}
    for label, kw_weight, emb_weight in (
        ("pure_keyword", 1.0, 0.0),
        ("pure_embedding", 0.0, 1.0),
        ("hybrid_40_60", 0.4, 0.6),
        ("hybrid_50_50", 0.5, 0.5),
        ("hybrid_60_40", 0.6, 0.4),
    )
]

# Rule/LLM combinations for the hybrid timing ablation.
# Only the two combined configs carry an "order" key naming the component that runs first.
HYBRID_TIMING_CONFIGS = [
    dict(name="rule_only", use_rule=True, use_llm=False),
    dict(name="llm_only", use_rule=False, use_llm=True),
    dict(name="rule_then_llm", use_rule=True, use_llm=True, order="rule_first"),
    dict(name="llm_then_rule", use_rule=True, use_llm=True, order="llm_first"),
]
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _count_selection_hits(selector: Any, samples: list[Any], top_k: int) -> tuple[int, list[str]]:
    """Score ``selector`` on ``samples`` with a hit@top_k criterion.

    A sample is a hit when at least one ground-truth tool id appears among the
    first ``top_k`` predicted ids. Samples whose evaluation raises are counted as
    misses, and their error messages are returned so the caller can report them
    instead of hiding them.

    Returns:
        ``(hits, error_messages)``.
    """
    hits = 0
    errors: list[str] = []
    for sample in samples:
        try:
            preds = selector.select(
                sample.get("instruction", ""),
                candidate_tools=sample.get("candidate_tools", []),
                top_k=top_k,
            )
            # Predictions may be objects exposing ``tool_id`` or plain values.
            pred_ids = [p.tool_id if hasattr(p, "tool_id") else str(p) for p in preds] if preds else []

            # Ground truth may be a single id or any collection of ids.
            ground_truth = sample.get("ground_truth", [])
            if isinstance(ground_truth, (list, tuple, set, frozenset)):
                ref_set = set(ground_truth)
            else:
                ref_set = {ground_truth}

            if set(pred_ids[:top_k]) & ref_set:
                hits += 1
        except Exception as e:  # best effort: one bad sample must not abort the sweep
            errors.append(f"{type(e).__name__}: {e}")
    return hits, errors


def run_hybrid_selection_ablation(
    max_samples: int = 50,
    verbose: bool = True,
    top_k: int = 5,
) -> dict[str, dict[str, float]]:
    """Run the hybrid-selection component ablation over keyword/embedding weights.

    Each entry of ``HYBRID_SELECTION_CONFIGS`` is mapped to a registry strategy:
    keyword weight 1.0 -> ``selector.keyword``, embedding weight 1.0 ->
    ``selector.embedding``, any mixed weighting -> ``selector.hybrid``.

    NOTE: the registry lookup takes no weight arguments, so every mixed-weight
    config currently runs the same default hybrid selector. A warning is printed
    for those configs. (TODO: support dynamic weight configuration.)

    Args:
        max_samples: Forwarded to ``load_benchmark_data``.
        verbose: Print weights and accuracy for each config.
        top_k: Number of predictions requested from the selector; a sample counts
            as a hit if any ground-truth tool is among them.

    Returns:
        ``{config_name: {"top_k_accuracy", "keyword_weight", "embedding_weight",
        "num_errors"}}``. Configs whose selector cannot be created are omitted.
        Returns an empty dict when no data is available or the adapter registry
        cannot be imported.
    """
    print_subsection_header("Hybrid Selection Component Ablation")

    samples = load_benchmark_data("selection", split="test", max_samples=max_samples)
    if not samples:
        print(" ❌ No selection data available")
        return {}

    try:
        from sage.benchmark.benchmark_agent import get_adapter_registry

        registry = get_adapter_registry()
    except ImportError:
        print(" ❌ Failed to import adapter registry")
        return {}

    results = {}

    for config in HYBRID_SELECTION_CONFIGS:
        config_name = config["name"]
        print(f"\n Config: {config_name}")

        if config["keyword_weight"] == 1.0:
            strategy = "selector.keyword"
        elif config["embedding_weight"] == 1.0:
            strategy = "selector.embedding"
        else:
            strategy = "selector.hybrid"
            # Make the aliasing visible instead of silently reporting nominal weights.
            print(f" ⚠️ Weights are not applied: running the default {strategy} selector")

        try:
            selector = registry.get(strategy)
        except Exception as e:
            print(f" ⚠️ Failed: {e}")
            continue

        hits, errors = _count_selection_hits(selector, samples, top_k)
        # samples is non-empty (guarded above); errored samples count as misses.
        accuracy = hits / len(samples)

        results[config_name] = {
            "top_k_accuracy": accuracy,
            "keyword_weight": config["keyword_weight"],
            "embedding_weight": config["embedding_weight"],
            "num_errors": len(errors),
        }

        if verbose:
            print(
                f" Weights: keyword={config['keyword_weight']}, embedding={config['embedding_weight']}"
            )
            print(f" Accuracy: {accuracy * 100:.1f}%")
        if errors:
            print(
                f" ⚠️ {len(errors)}/{len(samples)} samples raised errors "
                f"(counted as misses); first: {errors[0]}"
            )

    return results  # type: ignore[return-value]
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _count_timing_correct(
    detector: Any, samples: list[Any], message_cls: Any
) -> tuple[int, list[str]]:
    """Count samples where ``detector`` predicts the expected ``should_call_tool`` flag.

    Each sample is wrapped as ``message_cls(sample_id=..., message=..., context=...)``
    and passed to ``detector.decide``. A missing ``should_call_tool`` field is
    treated as ``False``. Samples whose evaluation raises are counted as incorrect,
    and their error messages are returned so the caller can report them.

    Returns:
        ``(correct, error_messages)``.
    """
    correct = 0
    errors: list[str] = []
    for sample in samples:
        try:
            message = message_cls(
                sample_id=sample.get("sample_id", ""),
                message=sample.get("message", ""),
                context=sample.get("context", {}),
            )
            predicted = detector.decide(message).should_call_tool
            if predicted == sample.get("should_call_tool", False):
                correct += 1
        except Exception as e:  # best effort: one bad sample must not abort the sweep
            errors.append(f"{type(e).__name__}: {e}")
    return correct, errors


def run_hybrid_timing_ablation(
    max_samples: int = 50,
    verbose: bool = True,
) -> dict[str, dict[str, float]]:
    """Run the hybrid-timing component ablation over rule/LLM combinations.

    Config-to-strategy mapping: ``rule_only`` -> ``timing.rule_based``,
    ``llm_only`` -> ``timing.llm_based``, anything else -> ``timing.hybrid``.

    NOTE: the registry lookup takes no ordering argument, so ``rule_then_llm`` and
    ``llm_then_rule`` currently run the same default hybrid detector. A warning is
    printed for them.

    Args:
        max_samples: Forwarded to ``load_benchmark_data``.
        verbose: Print the accuracy for each config.

    Returns:
        ``{config_name: {"accuracy", "use_rule", "use_llm", "num_errors"}}``.
        Configs whose detector cannot be created are omitted. Returns an empty
        dict when no data is available or a required import fails.
    """
    print_subsection_header("Hybrid Timing Component Ablation")

    samples = load_benchmark_data("timing", split="test", max_samples=max_samples)
    if not samples:
        print(" ❌ No timing data available")
        return {}

    try:
        from sage.benchmark.benchmark_agent import get_adapter_registry

        registry = get_adapter_registry()
    except ImportError:
        print(" ❌ Failed to import adapter registry")
        return {}

    # Import once, up front. Inside the per-sample try, an ImportError was
    # swallowed for every sample and reported as 0% accuracy.
    try:
        from sage.benchmark.benchmark_agent.experiments.timing_detection_exp import (
            TimingMessage,
        )
    except ImportError as e:
        print(f" ❌ Failed to import TimingMessage: {e}")
        return {}

    results = {}

    for config in HYBRID_TIMING_CONFIGS:
        config_name = config["name"]
        print(f"\n Config: {config_name}")

        # Map the config onto an actual registry strategy.
        if config_name == "rule_only":
            strategy = "timing.rule_based"
        elif config_name == "llm_only":
            strategy = "timing.llm_based"
        else:
            strategy = "timing.hybrid"
            print(
                f" ⚠️ Order '{config.get('order')}' is not applied: "
                f"running the default {strategy} detector"
            )

        try:
            detector = registry.get(strategy)
        except Exception as e:
            print(f" ⚠️ Failed: {e}")
            continue

        correct, errors = _count_timing_correct(detector, samples, TimingMessage)
        # samples is non-empty (guarded above); errored samples count as incorrect.
        accuracy = correct / len(samples)

        results[config_name] = {
            "accuracy": accuracy,
            "use_rule": config.get("use_rule", False),
            "use_llm": config.get("use_llm", False),
            "num_errors": len(errors),
        }

        if verbose:
            print(f" Accuracy: {accuracy * 100:.1f}%")
        if errors:
            print(
                f" ⚠️ {len(errors)}/{len(samples)} samples raised errors "
                f"(counted as incorrect); first: {errors[0]}"
            )

    return results  # type: ignore[return-value]
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
# =============================================================================
# Main Experiment
# =============================================================================


def run_ablation_study(
    prompt_ablation: bool = True,
    hybrid_ablation: bool = True,
    max_samples: int = 50,
    verbose: bool = True,
) -> dict[str, Any]:
    """Run the Section 5.3.4 ablation study end to end.

    Sets up the experiment environment, then runs the requested sub-experiments
    in order: prompt, hybrid selection, hybrid timing. The combined results are
    passed to ``save_results`` under ``("5_3_analysis", "ablation_analysis")``,
    and figure generation is attempted.

    Args:
        prompt_ablation: Include the prompt-design ablation.
        hybrid_ablation: Include both hybrid ablations (selection weights and
            timing components).
        max_samples: Sample cap forwarded to every sub-experiment.
        verbose: Forwarded to the environment setup and every sub-experiment.

    Returns:
        Mapping from sub-experiment key (``prompt_ablation``,
        ``hybrid_selection_ablation``, ``hybrid_timing_ablation``) to that
        experiment's results. Only the requested keys are present.
    """
    setup_experiment_env(verbose=verbose)

    print_section_header("Section 5.3.4: Ablation Studies")

    # (enabled?, result key, runner). Order matters: runners execute top to bottom.
    plan = (
        (prompt_ablation, "prompt_ablation", run_prompt_ablation),
        (hybrid_ablation, "hybrid_selection_ablation", run_hybrid_selection_ablation),
        (hybrid_ablation, "hybrid_timing_ablation", run_hybrid_timing_ablation),
    )

    study: dict[str, Any] = {}
    for enabled, result_key, runner in plan:
        if not enabled:
            continue
        study[result_key] = runner(max_samples=max_samples, verbose=verbose)

    saved_to = save_results(study, "5_3_analysis", "ablation_analysis")
    print(f"\n Results saved to: {saved_to}")

    _generate_ablation_figures(study)

    return study
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def _generate_ablation_figures(results: dict) -> None:
    """Render the ablation figures on a best-effort basis; never raises.

    Plots top-k accuracy for the prompt ablation and for the hybrid selection
    ablation via ``plot_ablation_heatmap``. Sections that are missing or empty
    are skipped. Every failure is printed as a warning, so a plotting problem
    cannot abort an otherwise finished experiment run. Each figure is attempted
    independently.

    Args:
        results: Combined study results; reads
            ``results[<section>][<config>]["top_k_accuracy"]``.
    """
    # figure_generator is a sibling module in this package (this file already uses
    # relative imports such as ``.exp_utils``). A bare absolute import fails in
    # package context, so figures were silently never produced. Keep the absolute
    # form only as a fallback for runs with the scripts directory on sys.path.
    try:
        from .figure_generator import plot_ablation_heatmap
    except ImportError:
        try:
            from figure_generator import plot_ablation_heatmap  # type: ignore[no-redef]
        except ImportError as e:
            print(f" Warning: Could not generate figures: {e}")
            return

    try:
        figures_dir = get_figures_dir()
    except Exception as e:
        print(f" Warning: Could not generate figures: {e}")
        return

    # (results key, figure title, output file name)
    figure_specs = (
        ("prompt_ablation", "Prompt Design Ablation", "fig8_analysis_ablation_prompt.pdf"),
        (
            "hybrid_selection_ablation",
            "Hybrid Selection Weight Ablation",
            "fig8_analysis_ablation_hybrid.pdf",
        ),
    )

    for section_key, title, filename in figure_specs:
        section = results.get(section_key)
        if not section:
            # Sub-experiment not requested, or it produced no results: nothing to plot.
            continue
        try:
            plot_data = {
                name: {"accuracy": metrics["top_k_accuracy"]} for name, metrics in section.items()
            }
            plot_ablation_heatmap(
                plot_data,
                title=title,
                output_path=figures_dir / filename,
            )
        except Exception as e:
            # One broken figure must not prevent the remaining ones from being produced.
            print(f" Warning: Could not generate {filename}: {e}")
            continue
        print(f" Figure saved: {filename}")
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def main():
|
|
402
|
+
parser = argparse.ArgumentParser(description="Section 5.3.4: Ablation Studies")
|
|
403
|
+
parser.add_argument("--prompt-ablation", action="store_true", help="Run prompt ablation only")
|
|
404
|
+
parser.add_argument("--hybrid-ablation", action="store_true", help="Run hybrid ablation only")
|
|
405
|
+
parser.add_argument("--max-samples", type=int, default=50, help="Maximum samples per test")
|
|
406
|
+
parser.add_argument("--verbose", action="store_true", default=True, help="Verbose output")
|
|
407
|
+
args = parser.parse_args()
|
|
408
|
+
|
|
409
|
+
# 如果没有指定具体类型,运行所有
|
|
410
|
+
run_all = not (args.prompt_ablation or args.hybrid_ablation)
|
|
411
|
+
|
|
412
|
+
run_ablation_study(
|
|
413
|
+
prompt_ablation=args.prompt_ablation or run_all,
|
|
414
|
+
hybrid_ablation=args.hybrid_ablation or run_all,
|
|
415
|
+
max_samples=args.max_samples,
|
|
416
|
+
verbose=args.verbose,
|
|
417
|
+
)
|
|
418
|
+
|
|
419
|
+
print("\n" + "=" * 70)
|
|
420
|
+
print("📊 Ablation Study Complete")
|
|
421
|
+
print("=" * 70)
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
if __name__ == "__main__":
|
|
425
|
+
main()
|