isage-benchmark-agent 0.1.0.1 (cp311-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
sage/benchmark/benchmark_agent/experiments/method_comparison.py
@@ -0,0 +1,742 @@
"""
Method Comparison Framework for Agent Training Experiments

Provides infrastructure to compare different training methods:
- Method A: Baseline (no coreset, no continual learning)
- Method B: Coreset Selection (loss_topk, diversity, hybrid, random)
- Method C: Online Continual Learning
- Method D: Coreset + Continual Learning (combined)

Features:
- Automatic experiment execution
- Result collection and aggregation
- Comparison chart generation
- Statistical analysis
"""

from __future__ import annotations

import json
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Literal, Optional

logger = logging.getLogger(__name__)


@dataclass
class MethodConfig:
    """Configuration for a single training method."""

    name: str
    description: str

    # Coreset settings
    use_coreset: bool = False
    coreset_strategy: Literal["loss_topk", "diversity", "hybrid", "random"] = "loss_topk"
    coreset_target_size: Optional[int] = None

    # Continual learning settings
    use_continual: bool = False
    continual_buffer_size: int = 2048
    continual_replay_ratio: float = 0.25

    # Training settings
    max_train_samples: Optional[int] = None
    num_epochs: int = 1
    learning_rate: float = 2e-5

    # Advanced LoRA methods (Task B4)
    use_dora: bool = False  # DoRA: Weight-Decomposed LoRA
    use_lora_plus: bool = False  # LoRA+: Differentiated learning rates
    lora_plus_lr_ratio: float = 16.0  # B matrix lr = base_lr * ratio

    # FireAct trajectory fine-tuning (Task B1)
    use_trajectory_collection: bool = False  # Enable FireAct-style trajectory collection
    trajectory_min_reward: float = 0.5  # Minimum reward for filtering trajectories
    trajectory_require_success: bool = True  # Only use successful trajectories
    trajectory_max_steps: int = 10  # Maximum steps per trajectory

    # AgentTuning multi-task training (Task B2)
    use_multi_task: bool = False  # Enable AgentTuning-style multi-task mixing
    task_weights: Optional[dict[str, float]] = None  # Task type weights
    mixing_strategy: Literal["weighted", "balanced", "curriculum"] = "weighted"

    def to_dict(self) -> dict:
        return {
            "name": self.name,
            "description": self.description,
            "use_coreset": self.use_coreset,
            "coreset_strategy": self.coreset_strategy,
            "coreset_target_size": self.coreset_target_size,
            "use_continual": self.use_continual,
            "continual_buffer_size": self.continual_buffer_size,
            "continual_replay_ratio": self.continual_replay_ratio,
            "max_train_samples": self.max_train_samples,
            "num_epochs": self.num_epochs,
            "learning_rate": self.learning_rate,
            "use_dora": self.use_dora,
            "use_lora_plus": self.use_lora_plus,
            "lora_plus_lr_ratio": self.lora_plus_lr_ratio,
            "use_trajectory_collection": self.use_trajectory_collection,
            "trajectory_min_reward": self.trajectory_min_reward,
            "trajectory_require_success": self.trajectory_require_success,
            "trajectory_max_steps": self.trajectory_max_steps,
            "use_multi_task": self.use_multi_task,
            "task_weights": self.task_weights,
            "mixing_strategy": self.mixing_strategy,
        }


@dataclass
class ExperimentResult:
    """Result from a single method experiment."""

    method_name: str
    config: dict
    metrics: dict[str, float]
    training_time_seconds: float
    eval_time_seconds: float
    num_train_samples: int
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> dict:
        return {
            "method_name": self.method_name,
            "config": self.config,
            "metrics": self.metrics,
            "training_time_seconds": self.training_time_seconds,
            "eval_time_seconds": self.eval_time_seconds,
            "num_train_samples": self.num_train_samples,
            "timestamp": self.timestamp,
        }


class MethodRegistry:
    """Registry of predefined training methods."""

    @staticmethod
    def get_all_methods() -> dict[str, MethodConfig]:
        """Get all predefined methods for comparison."""
        return {
            "A_baseline": MethodConfig(
                name="A: Baseline",
                description="Standard SFT without coreset or continual learning",
                use_coreset=False,
                use_continual=False,
            ),
            "B1_coreset_loss": MethodConfig(
                name="B1: Coreset (Loss Top-K)",
                description="Select samples with highest loss values",
                use_coreset=True,
                coreset_strategy="loss_topk",
                coreset_target_size=1000,
            ),
            "B2_coreset_diversity": MethodConfig(
                name="B2: Coreset (Diversity)",
                description="Select diverse samples using feature distance",
                use_coreset=True,
                coreset_strategy="diversity",
                coreset_target_size=1000,
            ),
            "B3_coreset_hybrid": MethodConfig(
                name="B3: Coreset (Hybrid)",
                description="60% loss-based + 40% diversity-based selection",
                use_coreset=True,
                coreset_strategy="hybrid",
                coreset_target_size=1000,
            ),
            "B4_coreset_random": MethodConfig(
                name="B4: Coreset (Random)",
                description="Random subset selection (control)",
                use_coreset=True,
                coreset_strategy="random",
                coreset_target_size=1000,
            ),
            "C_continual": MethodConfig(
                name="C: Continual Learning",
                description="Online continual learning with replay buffer",
                use_coreset=False,
                use_continual=True,
                continual_buffer_size=2048,
                continual_replay_ratio=0.25,
            ),
            "D_combined": MethodConfig(
                name="D: Coreset + Continual",
                description="Combined coreset selection and continual learning",
                use_coreset=True,
                coreset_strategy="hybrid",
                coreset_target_size=1500,
                use_continual=True,
                continual_buffer_size=2048,
                continual_replay_ratio=0.20,
            ),
            # Agent trajectory fine-tuning methods (Task B1: FireAct)
            "E_fireact": MethodConfig(
                name="E: FireAct",
                description="Agent trajectory fine-tuning (Chen et al., 2023)",
                use_trajectory_collection=True,
                trajectory_min_reward=0.5,
                trajectory_require_success=True,
                trajectory_max_steps=10,
                num_epochs=2,
            ),
            "F_fireact_coreset": MethodConfig(
                name="F: FireAct + Coreset",
                description="FireAct trajectory collection with coreset selection",
                use_trajectory_collection=True,
                trajectory_min_reward=0.5,
                trajectory_require_success=True,
                use_coreset=True,
                coreset_strategy="hybrid",
                coreset_target_size=1000,
                num_epochs=2,
            ),
            # Advanced LoRA methods (Task B4: DoRA/LoRA+)
            "G_dora": MethodConfig(
                name="G: DoRA",
                description="Weight-Decomposed Low-Rank Adaptation (Liu et al., 2024)",
                use_dora=True,
            ),
            "H_lora_plus": MethodConfig(
                name="H: LoRA+",
                description="LoRA with differentiated learning rates (Hayou et al., 2024)",
                use_lora_plus=True,
                lora_plus_lr_ratio=16.0,
            ),
            "I_dora_coreset": MethodConfig(
                name="I: DoRA + Coreset",
                description="DoRA combined with hybrid coreset selection",
                use_dora=True,
                use_coreset=True,
                coreset_strategy="hybrid",
                coreset_target_size=1000,
            ),
            "J_loraplus_continual": MethodConfig(
                name="J: LoRA+ + Continual",
                description="LoRA+ combined with continual learning",
                use_lora_plus=True,
                lora_plus_lr_ratio=16.0,
                use_continual=True,
                continual_buffer_size=2048,
                continual_replay_ratio=0.25,
            ),
            # AgentTuning multi-task training (Task B2)
            "F_agenttuning": MethodConfig(
                name="F: AgentTuning",
                description="Multi-task agent capability tuning (Zeng et al., 2023)",
                use_multi_task=True,
                task_weights={
                    "tool_selection": 0.35,
                    "planning": 0.30,
                    "timing": 0.20,
                    "general": 0.15,
                },
                mixing_strategy="weighted",
                num_epochs=2,
            ),
            "F2_agenttuning_curriculum": MethodConfig(
                name="F2: AgentTuning (Curriculum)",
                description="AgentTuning with curriculum learning strategy",
                use_multi_task=True,
                task_weights={
                    "tool_selection": 0.35,
                    "planning": 0.30,
                    "timing": 0.20,
                    "general": 0.15,
                },
                mixing_strategy="curriculum",
                num_epochs=3,
            ),
            "F3_agenttuning_coreset": MethodConfig(
                name="F3: AgentTuning + Coreset",
                description="AgentTuning combined with coreset selection",
                use_multi_task=True,
                task_weights={
                    "tool_selection": 0.35,
                    "planning": 0.30,
                    "timing": 0.20,
                    "general": 0.15,
                },
                mixing_strategy="weighted",
                use_coreset=True,
                coreset_strategy="hybrid",
                coreset_target_size=1000,
                num_epochs=2,
            ),
        }

    @staticmethod
    def get_quick_methods() -> dict[str, MethodConfig]:
        """Get a smaller set of methods for quick testing."""
        return {
            "A_baseline": MethodConfig(
                name="A: Baseline",
                description="Standard SFT",
                max_train_samples=200,
                num_epochs=1,
            ),
            "B_coreset": MethodConfig(
                name="B: Coreset (Hybrid)",
                description="Hybrid coreset selection",
                use_coreset=True,
                coreset_strategy="hybrid",
                coreset_target_size=150,
                max_train_samples=200,
                num_epochs=1,
            ),
            "C_continual": MethodConfig(
                name="C: Continual",
                description="Continual learning",
                use_continual=True,
                continual_buffer_size=100,
                continual_replay_ratio=0.3,
                max_train_samples=200,
                num_epochs=1,
            ),
            "E_fireact": MethodConfig(
                name="E: FireAct",
                description="Agent trajectory fine-tuning",
                use_trajectory_collection=True,
                trajectory_min_reward=0.3,
                trajectory_require_success=False,
                trajectory_max_steps=5,
                max_train_samples=200,
                num_epochs=1,
            ),
        }


class MethodComparisonExperiment:
    """
    Run comparison experiments across multiple training methods.

    Example:
        >>> exp = MethodComparisonExperiment(output_dir="./comparison_results")
        >>> exp.run_all_methods()
        >>> exp.generate_comparison_chart()
    """

    def __init__(
        self,
        output_dir: str | Path = "./comparison_results",
        base_model: str = "Qwen/Qwen2.5-1.5B-Instruct",
        methods: Optional[dict[str, MethodConfig]] = None,
        dry_run: bool = False,
    ):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.base_model = base_model
        self.methods = methods or MethodRegistry.get_quick_methods()
        self.dry_run = dry_run
        self.results: list[ExperimentResult] = []

    def run_all_methods(self, skip_training: bool = False) -> list[ExperimentResult]:
        """Run experiments for all configured methods."""
        print("=" * 70)
        print("METHOD COMPARISON EXPERIMENT")
        print("=" * 70)
        print(f"Output directory: {self.output_dir}")
        print(f"Base model: {self.base_model}")
        print(f"Methods to compare: {len(self.methods)}")
        print()

        for method_id, config in self.methods.items():
            print(f"\n{'=' * 50}")
            print(f"Running: {config.name}")
            print(f"{'=' * 50}")
            print(f"Description: {config.description}")

            if self.dry_run:
                result = self._simulate_run(method_id, config)
            else:
                result = self._run_method(method_id, config, skip_training)

            self.results.append(result)
            self._save_result(result)

            print(f"\nResults for {config.name}:")
            for metric, value in result.metrics.items():
                print(f"  {metric}: {value:.4f}")

        # Save aggregated results
        self._save_all_results()

        return self.results

    def _run_method(
        self, method_id: str, config: MethodConfig, skip_training: bool
    ) -> ExperimentResult:
        """Run a single method experiment."""
        from sage.benchmark.benchmark_agent import (
            ToolSelectionConfig,
            ToolSelectionExperiment,
            get_adapter_registry,
        )
        from sage.benchmark.benchmark_agent.evaluation import compute_metrics
        from sage.data import DataManager

        train_time = 0.0
        num_samples = 0

        if not skip_training:
            # Training phase (if not dry run and training enabled)
            train_start = time.time()
            try:
                from sage.libs.finetune.agent import AgentSFTConfig, AgentSFTTrainer

                sft_config = AgentSFTConfig(
                    base_model=self.base_model,
                    train_data="agent_sft:train",
                    dev_data="agent_sft:dev",
                    max_train_samples=config.max_train_samples,
                    num_epochs=config.num_epochs,
                    learning_rate=config.learning_rate,
                    use_coreset_selection=config.use_coreset,
                    coreset_strategy=config.coreset_strategy,
                    coreset_target_size=config.coreset_target_size,
                    use_online_continual=config.use_continual,
                    continual_buffer_size=config.continual_buffer_size,
                    continual_replay_ratio=config.continual_replay_ratio,
                    # DoRA and LoRA+ settings (Task B4)
                    use_dora=config.use_dora,
                    use_lora_plus=config.use_lora_plus,
                    lora_plus_lr_ratio=config.lora_plus_lr_ratio,
                    output_dir=self.output_dir / method_id,
                )
                trainer = AgentSFTTrainer(sft_config)
                trainer.train()
                num_samples = len(trainer._train_samples)
            except Exception as e:
                logger.warning(f"Training failed for {method_id}: {e}")
                num_samples = config.max_train_samples or 4000

            train_time = time.time() - train_start

        # Evaluation phase
        eval_start = time.time()

        dm = DataManager.get_instance()
        registry = get_adapter_registry()

        eval_config = ToolSelectionConfig(
            experiment="tool_selection",
            profile="quick_eval",
            split="test",
            selector="baseline.keyword",
            top_k=5,
            max_samples=100,
            verbose=False,
        )

        exp = ToolSelectionExperiment(eval_config, data_manager=dm, adapter_registry=registry)
        exp.prepare()
        result = exp.run()

        metrics = compute_metrics(
            task="tool_selection",
            predictions=result.predictions,
            references=result.references,
            metrics=["top_k_accuracy", "recall_at_k", "precision_at_k", "mrr"],
            k=5,
        )

        eval_time = time.time() - eval_start

        # Clean up metrics (remove error entries)
        clean_metrics = {
            k: v for k, v in metrics.items() if "_error" not in k and isinstance(v, float)
        }

        return ExperimentResult(
            method_name=config.name,
            config=config.to_dict(),
            metrics=clean_metrics,
            training_time_seconds=train_time,
            eval_time_seconds=eval_time,
            num_train_samples=num_samples,
        )

    def _simulate_run(self, method_id: str, config: MethodConfig) -> ExperimentResult:
        """Simulate a run for testing (dry run mode)."""
        import random

        # Generate simulated metrics with method-specific biases
        base_acc = 0.70
        if config.use_coreset:
            if config.coreset_strategy == "hybrid":
                base_acc += 0.08
            elif config.coreset_strategy == "diversity":
                base_acc += 0.05
            elif config.coreset_strategy == "loss_topk":
                base_acc += 0.06
        if config.use_continual:
            base_acc += 0.04

        noise = random.uniform(-0.03, 0.03)

        metrics = {
            "top_k_accuracy": min(base_acc + noise, 0.95),
            "recall_at_k": min((base_acc + noise) * 0.7, 0.85),
            "precision_at_k": min((base_acc + noise) * 0.4, 0.50),
            "mrr": min((base_acc + noise) * 0.6, 0.75),
        }

        return ExperimentResult(
            method_name=config.name,
            config=config.to_dict(),
            metrics=metrics,
            training_time_seconds=random.uniform(100, 500),
            eval_time_seconds=random.uniform(10, 30),
            num_train_samples=config.max_train_samples or 4000,
        )

    def _save_result(self, result: ExperimentResult):
        """Save individual result to JSON."""
        result_path = (
            self.output_dir / f"{result.method_name.replace(' ', '_').replace(':', '')}.json"
        )
        with open(result_path, "w") as f:
            json.dump(result.to_dict(), f, indent=2, ensure_ascii=False)

    def _save_all_results(self):
        """Save all results to a single JSON file."""
        all_results_path = self.output_dir / "all_results.json"
        with open(all_results_path, "w") as f:
            json.dump([r.to_dict() for r in self.results], f, indent=2, ensure_ascii=False)
        print(f"\nAll results saved to: {all_results_path}")

    def generate_comparison_chart(
        self,
        output_file: Optional[str] = None,
        show_plot: bool = False,
    ) -> Path:
        """Generate comparison charts from experiment results."""
        if not self.results:
            raise ValueError("No results to plot. Run experiments first.")

        output_file = output_file or str(self.output_dir / "comparison_chart.png")

        try:
            import matplotlib.pyplot as plt
            import numpy as np
        except ImportError:
            logger.warning("matplotlib not installed. Generating text report instead.")
            return self._generate_text_report()

        # Prepare data
        methods = [r.method_name for r in self.results]
        metrics = list(self.results[0].metrics.keys())

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        fig.suptitle(
            "Agent Training Method Comparison\nTarget: 95%+ Tool Planning Accuracy", fontsize=14
        )

        # Color palette
        cmap = plt.colormaps.get_cmap("Set2")
        colors = [cmap(i / len(methods)) for i in range(len(methods))]

        # 1. Bar chart - All metrics comparison
        ax1 = axes[0, 0]
        x = np.arange(len(metrics))
        width = 0.8 / len(methods)

        for i, result in enumerate(self.results):
            values = [result.metrics.get(m, 0) for m in metrics]
            ax1.bar(x + i * width, values, width, label=result.method_name, color=colors[i])

        ax1.set_xlabel("Metrics")
        ax1.set_ylabel("Score")
        ax1.set_title("Performance Comparison by Metric")
        ax1.set_xticks(x + width * (len(methods) - 1) / 2)
        ax1.set_xticklabels([m.replace("_", " ").title() for m in metrics], rotation=15)
        ax1.legend(loc="upper right", fontsize=8)
        ax1.axhline(y=0.95, color="red", linestyle="--", alpha=0.7, label="Target (95%)")
        ax1.set_ylim(0, 1.0)
        ax1.grid(axis="y", alpha=0.3)

        # 2. Radar chart - Method profiles
        ax2 = axes[0, 1]
        angles: list[float] = list(
            np.linspace(0, 2 * np.pi, len(metrics), endpoint=False).astype(float)
        )
        angles += angles[:1]  # Close the polygon

        for i, result in enumerate(self.results):
            radar_values: list[float] = [float(result.metrics.get(m, 0)) for m in metrics]
            radar_values += radar_values[:1]
            ax2.plot(
                angles, radar_values, "o-", linewidth=2, label=result.method_name, color=colors[i]
            )
            ax2.fill(angles, radar_values, alpha=0.1, color=colors[i])

        ax2.set_xticks(angles[:-1])
        ax2.set_xticklabels([m.replace("_", " ").title() for m in metrics], fontsize=8)
        ax2.set_title("Method Performance Profile (Radar)")
        ax2.legend(loc="upper right", fontsize=7)
        ax2.set_ylim(0, 1.0)

        # 3. Training efficiency
        ax3 = axes[1, 0]
        train_times = [r.training_time_seconds / 60 for r in self.results]  # Convert to minutes
        top_k_accs = [r.metrics.get("top_k_accuracy", 0) for r in self.results]

        ax3.scatter(train_times, top_k_accs, c=range(len(methods)), cmap="Set2", s=200)
        for idx, (tx, ty, name) in enumerate(zip(train_times, top_k_accs, methods)):
            ax3.annotate(name, (tx, ty), textcoords="offset points", xytext=(5, 5), fontsize=8)

        ax3.set_xlabel("Training Time (minutes)")
        ax3.set_ylabel("Top-K Accuracy")
        ax3.set_title("Training Efficiency")
        ax3.axhline(y=0.95, color="red", linestyle="--", alpha=0.7)
        ax3.grid(alpha=0.3)

        # 4. Summary table
        ax4 = axes[1, 1]
        ax4.axis("off")

        table_data = []
        headers = ["Method", "Top-K Acc", "Recall@K", "MRR", "Train Time"]

        for r in self.results:
            row = [
                r.method_name[:20],
                f"{r.metrics.get('top_k_accuracy', 0):.1%}",
                f"{r.metrics.get('recall_at_k', 0):.1%}",
                f"{r.metrics.get('mrr', 0):.1%}",
                f"{r.training_time_seconds / 60:.1f}m",
            ]
            table_data.append(row)

        table = ax4.table(
            cellText=table_data,
            colLabels=headers,
            cellLoc="center",
            loc="center",
            colColours=["lightblue"] * len(headers),
        )
        table.auto_set_font_size(False)
        table.set_fontsize(9)
        table.scale(1.2, 1.5)
        ax4.set_title("Results Summary", pad=20)

        plt.tight_layout()
        plt.savefig(output_file, dpi=150, bbox_inches="tight")
        print(f"\nChart saved to: {output_file}")

        if show_plot:
            plt.show()

        plt.close()

        return Path(output_file)

    def _generate_text_report(self) -> Path:
        """Generate a text-based report when matplotlib is not available."""
        report_path = self.output_dir / "comparison_report.txt"

        lines = [
            "=" * 70,
            "METHOD COMPARISON REPORT",
            "=" * 70,
            f"Generated: {datetime.now().isoformat()}",
            "",
            "-" * 70,
            "RESULTS SUMMARY",
            "-" * 70,
            "",
        ]

        # Find best method for each metric
        metrics = list(self.results[0].metrics.keys()) if self.results else []
        best_per_metric = {}
        for metric in metrics:
            best_result = max(self.results, key=lambda r: r.metrics.get(metric, 0))
            best_per_metric[metric] = (best_result.method_name, best_result.metrics.get(metric, 0))

        for result in self.results:
            lines.append(f"\n{result.method_name}")
            lines.append("-" * 40)
            for metric, value in result.metrics.items():
                is_best = best_per_metric.get(metric, ("", 0))[0] == result.method_name
                star = " ★" if is_best else ""
                lines.append(f"  {metric}: {value:.4f} ({value * 100:.1f}%){star}")
            lines.append(f"  Training time: {result.training_time_seconds / 60:.1f} min")
            lines.append(f"  Train samples: {result.num_train_samples}")

        lines.extend(
            [
                "",
                "-" * 70,
                "BEST PERFORMERS",
                "-" * 70,
            ]
        )
        for metric, (method, value) in best_per_metric.items():
            lines.append(f"  {metric}: {method} ({value * 100:.1f}%)")

        lines.extend(
            [
                "",
                "-" * 70,
                "TARGET: 95% accuracy for Hard Problem 4",
                "-" * 70,
            ]
        )

        with open(report_path, "w") as f:
            f.write("\n".join(lines))

        print(f"\nText report saved to: {report_path}")
        return report_path

    def load_results(self, results_file: Optional[str] = None) -> list[ExperimentResult]:
        """Load results from a previous run."""
        results_file = results_file or str(self.output_dir / "all_results.json")

        with open(results_file) as f:
            data = json.load(f)

        self.results = [
            ExperimentResult(
                method_name=r["method_name"],
                config=r["config"],
                metrics=r["metrics"],
                training_time_seconds=r["training_time_seconds"],
                eval_time_seconds=r["eval_time_seconds"],
                num_train_samples=r["num_train_samples"],
                timestamp=r.get("timestamp", ""),
            )
            for r in data
        ]

        return self.results


def run_quick_comparison(output_dir: str = "./comparison_results", dry_run: bool = True):
    """Quick comparison with simulated results for testing."""
    exp = MethodComparisonExperiment(
        output_dir=output_dir,
        methods=MethodRegistry.get_quick_methods(),
        dry_run=dry_run,
    )
    exp.run_all_methods()
    return exp.generate_comparison_chart()


def run_full_comparison(
    output_dir: str = "./comparison_results", base_model: str = "Qwen/Qwen2.5-7B-Instruct"
):
    """Full comparison with actual training (requires GPU)."""
    exp = MethodComparisonExperiment(
        output_dir=output_dir,
        base_model=base_model,
        methods=MethodRegistry.get_all_methods(),
        dry_run=False,
    )
    exp.run_all_methods()
    return exp.generate_comparison_chart()