isage-benchmark-agent 0.1.0.1__cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
  2. isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
  3. isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
  4. isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
  5. isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
  6. isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
  7. sage/__init__.py +0 -0
  8. sage/benchmark/__init__.py +0 -0
  9. sage/benchmark/benchmark_agent/__init__.py +108 -0
  10. sage/benchmark/benchmark_agent/__main__.py +177 -0
  11. sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
  12. sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
  13. sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
  14. sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
  15. sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
  16. sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
  17. sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
  18. sage/benchmark/benchmark_agent/data_paths.py +332 -0
  19. sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
  20. sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
  21. sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
  22. sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
  23. sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
  24. sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
  25. sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
  26. sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
  27. sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
  28. sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
  29. sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
  30. sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
  31. sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
  32. sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
  33. sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
  34. sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
  35. sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
  36. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
  37. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
  38. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
  39. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
  40. sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
  41. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
  42. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
  43. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
  44. sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
  45. sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
  46. sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
  47. sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
  48. sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
  49. sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
  50. sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
  51. sage/benchmark/benchmark_agent/tools_loader.py +212 -0
@@ -0,0 +1,742 @@
+ """
+ Method Comparison Framework for Agent Training Experiments
+
+ Provides infrastructure to compare different training methods:
+ - Method A: Baseline (no coreset, no continual learning)
+ - Method B: Coreset Selection (loss_topk, diversity, hybrid, random)
+ - Method C: Online Continual Learning
+ - Method D: Coreset + Continual Learning (combined)
+
+ Features:
+ - Automatic experiment execution
+ - Result collection and aggregation
+ - Comparison chart generation
+ - Statistical analysis
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ import time
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Literal, Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class MethodConfig:
+     """Configuration for a single training method."""
+
+     name: str
+     description: str
+
+     # Coreset settings
+     use_coreset: bool = False
+     coreset_strategy: Literal["loss_topk", "diversity", "hybrid", "random"] = "loss_topk"
+     coreset_target_size: Optional[int] = None
+
+     # Continual learning settings
+     use_continual: bool = False
+     continual_buffer_size: int = 2048
+     continual_replay_ratio: float = 0.25
+
+     # Training settings
+     max_train_samples: Optional[int] = None
+     num_epochs: int = 1
+     learning_rate: float = 2e-5
+
+     # Advanced LoRA methods (Task B4)
+     use_dora: bool = False  # DoRA: Weight-Decomposed LoRA
+     use_lora_plus: bool = False  # LoRA+: Differentiated learning rates
+     lora_plus_lr_ratio: float = 16.0  # B matrix lr = base_lr * ratio
+
+     # FireAct trajectory fine-tuning (Task B1)
+     use_trajectory_collection: bool = False  # Enable FireAct-style trajectory collection
+     trajectory_min_reward: float = 0.5  # Minimum reward for filtering trajectories
+     trajectory_require_success: bool = True  # Only use successful trajectories
+     trajectory_max_steps: int = 10  # Maximum steps per trajectory
+
+     # AgentTuning multi-task training (Task B2)
+     use_multi_task: bool = False  # Enable AgentTuning-style multi-task mixing
+     task_weights: Optional[dict[str, float]] = None  # Task type weights
+     mixing_strategy: Literal["weighted", "balanced", "curriculum"] = "weighted"
+
+     def to_dict(self) -> dict:
+         return {
+             "name": self.name,
+             "description": self.description,
+             "use_coreset": self.use_coreset,
+             "coreset_strategy": self.coreset_strategy,
+             "coreset_target_size": self.coreset_target_size,
+             "use_continual": self.use_continual,
+             "continual_buffer_size": self.continual_buffer_size,
+             "continual_replay_ratio": self.continual_replay_ratio,
+             "max_train_samples": self.max_train_samples,
+             "num_epochs": self.num_epochs,
+             "learning_rate": self.learning_rate,
+             "use_dora": self.use_dora,
+             "use_lora_plus": self.use_lora_plus,
+             "lora_plus_lr_ratio": self.lora_plus_lr_ratio,
+             "use_trajectory_collection": self.use_trajectory_collection,
+             "trajectory_min_reward": self.trajectory_min_reward,
+             "trajectory_require_success": self.trajectory_require_success,
+             "trajectory_max_steps": self.trajectory_max_steps,
+             "use_multi_task": self.use_multi_task,
+             "task_weights": self.task_weights,
+             "mixing_strategy": self.mixing_strategy,
+         }
+
+
+ @dataclass
+ class ExperimentResult:
+     """Result from a single method experiment."""
+
+     method_name: str
+     config: dict
+     metrics: dict[str, float]
+     training_time_seconds: float
+     eval_time_seconds: float
+     num_train_samples: int
+     timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
+
+     def to_dict(self) -> dict:
+         return {
+             "method_name": self.method_name,
+             "config": self.config,
+             "metrics": self.metrics,
+             "training_time_seconds": self.training_time_seconds,
+             "eval_time_seconds": self.eval_time_seconds,
+             "num_train_samples": self.num_train_samples,
+             "timestamp": self.timestamp,
+         }
+
+
+ class MethodRegistry:
+     """Registry of predefined training methods."""
+
+     @staticmethod
+     def get_all_methods() -> dict[str, MethodConfig]:
+         """Get all predefined methods for comparison."""
+         return {
+             "A_baseline": MethodConfig(
+                 name="A: Baseline",
+                 description="Standard SFT without coreset or continual learning",
+                 use_coreset=False,
+                 use_continual=False,
+             ),
+             "B1_coreset_loss": MethodConfig(
+                 name="B1: Coreset (Loss Top-K)",
+                 description="Select samples with highest loss values",
+                 use_coreset=True,
+                 coreset_strategy="loss_topk",
+                 coreset_target_size=1000,
+             ),
+             "B2_coreset_diversity": MethodConfig(
+                 name="B2: Coreset (Diversity)",
+                 description="Select diverse samples using feature distance",
+                 use_coreset=True,
+                 coreset_strategy="diversity",
+                 coreset_target_size=1000,
+             ),
+             "B3_coreset_hybrid": MethodConfig(
+                 name="B3: Coreset (Hybrid)",
+                 description="60% loss-based + 40% diversity-based selection",
+                 use_coreset=True,
+                 coreset_strategy="hybrid",
+                 coreset_target_size=1000,
+             ),
+             "B4_coreset_random": MethodConfig(
+                 name="B4: Coreset (Random)",
+                 description="Random subset selection (control)",
+                 use_coreset=True,
+                 coreset_strategy="random",
+                 coreset_target_size=1000,
+             ),
+             "C_continual": MethodConfig(
+                 name="C: Continual Learning",
+                 description="Online continual learning with replay buffer",
+                 use_coreset=False,
+                 use_continual=True,
+                 continual_buffer_size=2048,
+                 continual_replay_ratio=0.25,
+             ),
+             "D_combined": MethodConfig(
+                 name="D: Coreset + Continual",
+                 description="Combined coreset selection and continual learning",
+                 use_coreset=True,
+                 coreset_strategy="hybrid",
+                 coreset_target_size=1500,
+                 use_continual=True,
+                 continual_buffer_size=2048,
+                 continual_replay_ratio=0.20,
+             ),
+             # Agent trajectory fine-tuning methods (Task B1: FireAct)
+             "E_fireact": MethodConfig(
+                 name="E: FireAct",
+                 description="Agent trajectory fine-tuning (Chen et al., 2023)",
+                 use_trajectory_collection=True,
+                 trajectory_min_reward=0.5,
+                 trajectory_require_success=True,
+                 trajectory_max_steps=10,
+                 num_epochs=2,
+             ),
+             "F_fireact_coreset": MethodConfig(
+                 name="F: FireAct + Coreset",
+                 description="FireAct trajectory collection with coreset selection",
+                 use_trajectory_collection=True,
+                 trajectory_min_reward=0.5,
+                 trajectory_require_success=True,
+                 use_coreset=True,
+                 coreset_strategy="hybrid",
+                 coreset_target_size=1000,
+                 num_epochs=2,
+             ),
+             # Advanced LoRA methods (Task B4: DoRA/LoRA+)
+             "G_dora": MethodConfig(
+                 name="G: DoRA",
+                 description="Weight-Decomposed Low-Rank Adaptation (Liu et al., 2024)",
+                 use_dora=True,
+             ),
+             "H_lora_plus": MethodConfig(
+                 name="H: LoRA+",
+                 description="LoRA with differentiated learning rates (Hayou et al., 2024)",
+                 use_lora_plus=True,
+                 lora_plus_lr_ratio=16.0,
+             ),
+             "I_dora_coreset": MethodConfig(
+                 name="I: DoRA + Coreset",
+                 description="DoRA combined with hybrid coreset selection",
+                 use_dora=True,
+                 use_coreset=True,
+                 coreset_strategy="hybrid",
+                 coreset_target_size=1000,
+             ),
+             "J_loraplus_continual": MethodConfig(
+                 name="J: LoRA+ + Continual",
+                 description="LoRA+ combined with continual learning",
+                 use_lora_plus=True,
+                 lora_plus_lr_ratio=16.0,
+                 use_continual=True,
+                 continual_buffer_size=2048,
+                 continual_replay_ratio=0.25,
+             ),
+             # AgentTuning multi-task training (Task B2)
+             "F_agenttuning": MethodConfig(
+                 name="F: AgentTuning",
+                 description="Multi-task agent capability tuning (Zeng et al., 2023)",
+                 use_multi_task=True,
+                 task_weights={
+                     "tool_selection": 0.35,
+                     "planning": 0.30,
+                     "timing": 0.20,
+                     "general": 0.15,
+                 },
+                 mixing_strategy="weighted",
+                 num_epochs=2,
+             ),
+             "F2_agenttuning_curriculum": MethodConfig(
+                 name="F2: AgentTuning (Curriculum)",
+                 description="AgentTuning with curriculum learning strategy",
+                 use_multi_task=True,
+                 task_weights={
+                     "tool_selection": 0.35,
+                     "planning": 0.30,
+                     "timing": 0.20,
+                     "general": 0.15,
+                 },
+                 mixing_strategy="curriculum",
+                 num_epochs=3,
+             ),
+             "F3_agenttuning_coreset": MethodConfig(
+                 name="F3: AgentTuning + Coreset",
+                 description="AgentTuning combined with coreset selection",
+                 use_multi_task=True,
+                 task_weights={
+                     "tool_selection": 0.35,
+                     "planning": 0.30,
+                     "timing": 0.20,
+                     "general": 0.15,
+                 },
+                 mixing_strategy="weighted",
+                 use_coreset=True,
+                 coreset_strategy="hybrid",
+                 coreset_target_size=1000,
+                 num_epochs=2,
+             ),
+         }
+
+     @staticmethod
+     def get_quick_methods() -> dict[str, MethodConfig]:
+         """Get a smaller set of methods for quick testing."""
+         return {
+             "A_baseline": MethodConfig(
+                 name="A: Baseline",
+                 description="Standard SFT",
+                 max_train_samples=200,
+                 num_epochs=1,
+             ),
+             "B_coreset": MethodConfig(
+                 name="B: Coreset (Hybrid)",
+                 description="Hybrid coreset selection",
+                 use_coreset=True,
+                 coreset_strategy="hybrid",
+                 coreset_target_size=150,
+                 max_train_samples=200,
+                 num_epochs=1,
+             ),
+             "C_continual": MethodConfig(
+                 name="C: Continual",
+                 description="Continual learning",
+                 use_continual=True,
+                 continual_buffer_size=100,
+                 continual_replay_ratio=0.3,
+                 max_train_samples=200,
+                 num_epochs=1,
+             ),
+             "E_fireact": MethodConfig(
+                 name="E: FireAct",
+                 description="Agent trajectory fine-tuning",
+                 use_trajectory_collection=True,
+                 trajectory_min_reward=0.3,
+                 trajectory_require_success=False,
+                 trajectory_max_steps=5,
+                 max_train_samples=200,
+                 num_epochs=1,
+             ),
+         }
+
+
+ class MethodComparisonExperiment:
+     """
+     Run comparison experiments across multiple training methods.
+
+     Example:
+         >>> exp = MethodComparisonExperiment(output_dir="./comparison_results")
+         >>> exp.run_all_methods()
+         >>> exp.generate_comparison_chart()
+     """
+
+     def __init__(
+         self,
+         output_dir: str | Path = "./comparison_results",
+         base_model: str = "Qwen/Qwen2.5-1.5B-Instruct",
+         methods: Optional[dict[str, MethodConfig]] = None,
+         dry_run: bool = False,
+     ):
+         self.output_dir = Path(output_dir)
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+         self.base_model = base_model
+         self.methods = methods or MethodRegistry.get_quick_methods()
+         self.dry_run = dry_run
+         self.results: list[ExperimentResult] = []
+
+     def run_all_methods(self, skip_training: bool = False) -> list[ExperimentResult]:
+         """Run experiments for all configured methods."""
+         print("=" * 70)
+         print("METHOD COMPARISON EXPERIMENT")
+         print("=" * 70)
+         print(f"Output directory: {self.output_dir}")
+         print(f"Base model: {self.base_model}")
+         print(f"Methods to compare: {len(self.methods)}")
+         print()
+
+         for method_id, config in self.methods.items():
+             print(f"\n{'=' * 50}")
+             print(f"Running: {config.name}")
+             print(f"{'=' * 50}")
+             print(f"Description: {config.description}")
+
+             if self.dry_run:
+                 result = self._simulate_run(method_id, config)
+             else:
+                 result = self._run_method(method_id, config, skip_training)
+
+             self.results.append(result)
+             self._save_result(result)
+
+             print(f"\nResults for {config.name}:")
+             for metric, value in result.metrics.items():
+                 print(f" {metric}: {value:.4f}")
+
+         # Save aggregated results
+         self._save_all_results()
+
+         return self.results
+
+     def _run_method(
+         self, method_id: str, config: MethodConfig, skip_training: bool
+     ) -> ExperimentResult:
+         """Run a single method experiment."""
+         from sage.benchmark.benchmark_agent import (
+             ToolSelectionConfig,
+             ToolSelectionExperiment,
+             get_adapter_registry,
+         )
+         from sage.benchmark.benchmark_agent.evaluation import compute_metrics
+         from sage.data import DataManager
+
+         train_time = 0.0
+         num_samples = 0
+
+         if not skip_training:
+             # Training phase (if not dry run and training enabled)
+             train_start = time.time()
+             try:
+                 from sage.libs.finetune.agent import AgentSFTConfig, AgentSFTTrainer
+
+                 sft_config = AgentSFTConfig(
+                     base_model=self.base_model,
+                     train_data="agent_sft:train",
+                     dev_data="agent_sft:dev",
+                     max_train_samples=config.max_train_samples,
+                     num_epochs=config.num_epochs,
+                     learning_rate=config.learning_rate,
+                     use_coreset_selection=config.use_coreset,
+                     coreset_strategy=config.coreset_strategy,
+                     coreset_target_size=config.coreset_target_size,
+                     use_online_continual=config.use_continual,
+                     continual_buffer_size=config.continual_buffer_size,
+                     continual_replay_ratio=config.continual_replay_ratio,
+                     # DoRA and LoRA+ settings (Task B4)
+                     use_dora=config.use_dora,
+                     use_lora_plus=config.use_lora_plus,
+                     lora_plus_lr_ratio=config.lora_plus_lr_ratio,
+                     output_dir=self.output_dir / method_id,
+                 )
+                 trainer = AgentSFTTrainer(sft_config)
+                 trainer.train()
+                 num_samples = len(trainer._train_samples)
+             except Exception as e:
+                 logger.warning(f"Training failed for {method_id}: {e}")
+                 num_samples = config.max_train_samples or 4000
+
+             train_time = time.time() - train_start
+
+         # Evaluation phase
+         eval_start = time.time()
+
+         dm = DataManager.get_instance()
+         registry = get_adapter_registry()
+
+         eval_config = ToolSelectionConfig(
+             experiment="tool_selection",
+             profile="quick_eval",
+             split="test",
+             selector="baseline.keyword",
+             top_k=5,
+             max_samples=100,
+             verbose=False,
+         )
+
+         exp = ToolSelectionExperiment(eval_config, data_manager=dm, adapter_registry=registry)
+         exp.prepare()
+         result = exp.run()
+
+         metrics = compute_metrics(
+             task="tool_selection",
+             predictions=result.predictions,
+             references=result.references,
+             metrics=["top_k_accuracy", "recall_at_k", "precision_at_k", "mrr"],
+             k=5,
+         )
+
+         eval_time = time.time() - eval_start
+
+         # Clean up metrics (remove error entries)
+         clean_metrics = {
+             k: v for k, v in metrics.items() if "_error" not in k and isinstance(v, float)
+         }
+
+         return ExperimentResult(
+             method_name=config.name,
+             config=config.to_dict(),
+             metrics=clean_metrics,
+             training_time_seconds=train_time,
+             eval_time_seconds=eval_time,
+             num_train_samples=num_samples,
+         )
+
+     def _simulate_run(self, method_id: str, config: MethodConfig) -> ExperimentResult:
+         """Simulate a run for testing (dry run mode)."""
+         import random
+
+         # Generate simulated metrics with method-specific biases
+         base_acc = 0.70
+         if config.use_coreset:
+             if config.coreset_strategy == "hybrid":
+                 base_acc += 0.08
+             elif config.coreset_strategy == "diversity":
+                 base_acc += 0.05
+             elif config.coreset_strategy == "loss_topk":
+                 base_acc += 0.06
+         if config.use_continual:
+             base_acc += 0.04
+
+         noise = random.uniform(-0.03, 0.03)
+
+         metrics = {
+             "top_k_accuracy": min(base_acc + noise, 0.95),
+             "recall_at_k": min((base_acc + noise) * 0.7, 0.85),
+             "precision_at_k": min((base_acc + noise) * 0.4, 0.50),
+             "mrr": min((base_acc + noise) * 0.6, 0.75),
+         }
+
+         return ExperimentResult(
+             method_name=config.name,
+             config=config.to_dict(),
+             metrics=metrics,
+             training_time_seconds=random.uniform(100, 500),
+             eval_time_seconds=random.uniform(10, 30),
+             num_train_samples=config.max_train_samples or 4000,
+         )
+
+     def _save_result(self, result: ExperimentResult):
+         """Save individual result to JSON."""
+         result_path = (
+             self.output_dir / f"{result.method_name.replace(' ', '_').replace(':', '')}.json"
+         )
+         with open(result_path, "w") as f:
+             json.dump(result.to_dict(), f, indent=2, ensure_ascii=False)
+
+     def _save_all_results(self):
+         """Save all results to a single JSON file."""
+         all_results_path = self.output_dir / "all_results.json"
+         with open(all_results_path, "w") as f:
+             json.dump([r.to_dict() for r in self.results], f, indent=2, ensure_ascii=False)
+         print(f"\nAll results saved to: {all_results_path}")
+
+     def generate_comparison_chart(
+         self,
+         output_file: Optional[str] = None,
+         show_plot: bool = False,
+     ) -> Path:
+         """Generate comparison charts from experiment results."""
+         if not self.results:
+             raise ValueError("No results to plot. Run experiments first.")
+
+         output_file = output_file or str(self.output_dir / "comparison_chart.png")
+
+         try:
+             import matplotlib.pyplot as plt
+             import numpy as np
+         except ImportError:
+             logger.warning("matplotlib not installed. Generating text report instead.")
+             return self._generate_text_report()
+
+         # Prepare data
+         methods = [r.method_name for r in self.results]
+         metrics = list(self.results[0].metrics.keys())
+
+         fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+         fig.suptitle(
+             "Agent Training Method Comparison\nTarget: 95%+ Tool Planning Accuracy", fontsize=14
+         )
+
+         # Color palette
+         cmap = plt.colormaps.get_cmap("Set2")
+         colors = [cmap(i / len(methods)) for i in range(len(methods))]
+
+         # 1. Bar chart - All metrics comparison
+         ax1 = axes[0, 0]
+         x = np.arange(len(metrics))
+         width = 0.8 / len(methods)
+
+         for i, result in enumerate(self.results):
+             values = [result.metrics.get(m, 0) for m in metrics]
+             ax1.bar(x + i * width, values, width, label=result.method_name, color=colors[i])
+
+         ax1.set_xlabel("Metrics")
+         ax1.set_ylabel("Score")
+         ax1.set_title("Performance Comparison by Metric")
+         ax1.set_xticks(x + width * (len(methods) - 1) / 2)
+         ax1.set_xticklabels([m.replace("_", " ").title() for m in metrics], rotation=15)
+         ax1.legend(loc="upper right", fontsize=8)
+         ax1.axhline(y=0.95, color="red", linestyle="--", alpha=0.7, label="Target (95%)")
+         ax1.set_ylim(0, 1.0)
+         ax1.grid(axis="y", alpha=0.3)
+
+         # 2. Radar chart - Method profiles
+         ax2 = axes[0, 1]
+         angles: list[float] = list(
+             np.linspace(0, 2 * np.pi, len(metrics), endpoint=False).astype(float)
+         )
+         angles += angles[:1]  # Close the polygon
+
+         for i, result in enumerate(self.results):
+             radar_values: list[float] = [float(result.metrics.get(m, 0)) for m in metrics]
+             radar_values += radar_values[:1]
+             ax2.plot(
+                 angles, radar_values, "o-", linewidth=2, label=result.method_name, color=colors[i]
+             )
+             ax2.fill(angles, radar_values, alpha=0.1, color=colors[i])
+
+         ax2.set_xticks(angles[:-1])
+         ax2.set_xticklabels([m.replace("_", " ").title() for m in metrics], fontsize=8)
+         ax2.set_title("Method Performance Profile (Radar)")
+         ax2.legend(loc="upper right", fontsize=7)
+         ax2.set_ylim(0, 1.0)
+
+         # 3. Training efficiency
+         ax3 = axes[1, 0]
+         train_times = [r.training_time_seconds / 60 for r in self.results]  # Convert to minutes
+         top_k_accs = [r.metrics.get("top_k_accuracy", 0) for r in self.results]
+
+         ax3.scatter(train_times, top_k_accs, c=range(len(methods)), cmap="Set2", s=200)
+         for idx, (tx, ty, name) in enumerate(zip(train_times, top_k_accs, methods)):
+             ax3.annotate(name, (tx, ty), textcoords="offset points", xytext=(5, 5), fontsize=8)
+
+         ax3.set_xlabel("Training Time (minutes)")
+         ax3.set_ylabel("Top-K Accuracy")
+         ax3.set_title("Training Efficiency")
+         ax3.axhline(y=0.95, color="red", linestyle="--", alpha=0.7)
+         ax3.grid(alpha=0.3)
+
+         # 4. Summary table
+         ax4 = axes[1, 1]
+         ax4.axis("off")
+
+         table_data = []
+         headers = ["Method", "Top-K Acc", "Recall@K", "MRR", "Train Time"]
+
+         for r in self.results:
+             row = [
+                 r.method_name[:20],
+                 f"{r.metrics.get('top_k_accuracy', 0):.1%}",
+                 f"{r.metrics.get('recall_at_k', 0):.1%}",
+                 f"{r.metrics.get('mrr', 0):.1%}",
+                 f"{r.training_time_seconds / 60:.1f}m",
+             ]
+             table_data.append(row)
+
+         table = ax4.table(
+             cellText=table_data,
+             colLabels=headers,
+             cellLoc="center",
+             loc="center",
+             colColours=["lightblue"] * len(headers),
+         )
+         table.auto_set_font_size(False)
+         table.set_fontsize(9)
+         table.scale(1.2, 1.5)
+         ax4.set_title("Results Summary", pad=20)
+
+         plt.tight_layout()
+         plt.savefig(output_file, dpi=150, bbox_inches="tight")
+         print(f"\nChart saved to: {output_file}")
+
+         if show_plot:
+             plt.show()
+
+         plt.close()
+
+         return Path(output_file)
+
+     def _generate_text_report(self) -> Path:
+         """Generate a text-based report when matplotlib is not available."""
+         report_path = self.output_dir / "comparison_report.txt"
+
+         lines = [
+             "=" * 70,
+             "METHOD COMPARISON REPORT",
+             "=" * 70,
+             f"Generated: {datetime.now().isoformat()}",
+             "",
+             "-" * 70,
+             "RESULTS SUMMARY",
+             "-" * 70,
+             "",
+         ]
+
+         # Find best method for each metric
+         metrics = list(self.results[0].metrics.keys()) if self.results else []
+         best_per_metric = {}
+         for metric in metrics:
+             best_result = max(self.results, key=lambda r: r.metrics.get(metric, 0))
+             best_per_metric[metric] = (best_result.method_name, best_result.metrics.get(metric, 0))
+
+         for result in self.results:
+             lines.append(f"\n{result.method_name}")
+             lines.append("-" * 40)
+             for metric, value in result.metrics.items():
+                 is_best = best_per_metric.get(metric, ("", 0))[0] == result.method_name
+                 star = " ★" if is_best else ""
+                 lines.append(f" {metric}: {value:.4f} ({value * 100:.1f}%){star}")
+             lines.append(f" Training time: {result.training_time_seconds / 60:.1f} min")
+             lines.append(f" Train samples: {result.num_train_samples}")
+
+         lines.extend(
+             [
+                 "",
+                 "-" * 70,
+                 "BEST PERFORMERS",
+                 "-" * 70,
+             ]
+         )
+         for metric, (method, value) in best_per_metric.items():
+             lines.append(f" {metric}: {method} ({value * 100:.1f}%)")
+
+         lines.extend(
+             [
+                 "",
+                 "-" * 70,
+                 "TARGET: 95% accuracy for Challenge 4",
+                 "-" * 70,
+             ]
+         )
+
+         with open(report_path, "w") as f:
+             f.write("\n".join(lines))
+
+         print(f"\nText report saved to: {report_path}")
+         return report_path
+
+     def load_results(self, results_file: Optional[str] = None) -> list[ExperimentResult]:
+         """Load results from a previous run."""
+         results_file = results_file or str(self.output_dir / "all_results.json")
+
+         with open(results_file) as f:
+             data = json.load(f)
+
+         self.results = [
+             ExperimentResult(
+                 method_name=r["method_name"],
+                 config=r["config"],
+                 metrics=r["metrics"],
+                 training_time_seconds=r["training_time_seconds"],
+                 eval_time_seconds=r["eval_time_seconds"],
+                 num_train_samples=r["num_train_samples"],
+                 timestamp=r.get("timestamp", ""),
+             )
+             for r in data
+         ]
+
+         return self.results
+
+
+ def run_quick_comparison(output_dir: str = "./comparison_results", dry_run: bool = True):
+     """Quick comparison with simulated results for testing."""
+     exp = MethodComparisonExperiment(
+         output_dir=output_dir,
+         methods=MethodRegistry.get_quick_methods(),
+         dry_run=dry_run,
+     )
+     exp.run_all_methods()
+     return exp.generate_comparison_chart()
+
+
+ def run_full_comparison(
+     output_dir: str = "./comparison_results", base_model: str = "Qwen/Qwen2.5-7B-Instruct"
+ ):
+     """Full comparison with actual training (requires GPU)."""
+     exp = MethodComparisonExperiment(
+         output_dir=output_dir,
+         base_model=base_model,
+         methods=MethodRegistry.get_all_methods(),
+         dry_run=False,
+     )
+     exp.run_all_methods()
+     return exp.generate_comparison_chart()
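
For reference, a minimal sketch of how the module added in this diff might be exercised end to end. It mirrors the bundled run_quick_comparison helper and assumes the wheel is installed with the module layout shown in the file list above; in dry-run mode the metrics are simulated, so no GPU or training dependencies are required.

    # Usage sketch (not part of the published diff); assumes the installed wheel
    # exposes the module path listed in the RECORD above.
    from sage.benchmark.benchmark_agent.experiments.method_comparison import (
        MethodComparisonExperiment,
        MethodRegistry,
    )

    # dry_run=True routes each method through _simulate_run, so this only
    # exercises result collection, JSON export, and chart/report generation.
    exp = MethodComparisonExperiment(
        output_dir="./comparison_results",
        methods=MethodRegistry.get_quick_methods(),
        dry_run=True,
    )
    exp.run_all_methods()
    exp.generate_comparison_chart()  # falls back to a text report if matplotlib is absent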