isage-benchmark-agent 0.1.0.1 (cp311-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
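
Note on layout: top_level.txt declares the single top-level package `sage`, so the files above install into the `sage.benchmark.benchmark_agent` namespace. A minimal install/import sketch, assuming the PyPI project name matches the distribution name in the wheel (an assumption, not confirmed by this diff):

# Hypothetical install/import sketch; project name inferred from the wheel filename.
#   pip install isage-benchmark-agent==0.1.0.1
from sage.benchmark.benchmark_agent.experiments.planning_exp import PlanningExperiment
from sage.benchmark.benchmark_agent.experiments.timing_detection_exp import (
    TimingDetectionExperiment,
)

The two experiment modules imported here are the files whose diffs follow.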
sage/benchmark/benchmark_agent/experiments/planning_exp.py
@@ -0,0 +1,262 @@
"""
Planning Experiment

Experiment runner for evaluating agent task planning capabilities.
Tests ability to generate multi-step plans with tool sequences.
"""

import json
from pathlib import Path
from typing import Any

from sage.benchmark.benchmark_agent.experiments.base_experiment import (
    BaseExperiment,
    ExperimentResult,
    PlanningConfig,
)


class PlanningTask:
    """Task for planning."""

    def __init__(
        self, sample_id: str, instruction: str, context: dict[str, Any], available_tools: list[str]
    ):
        self.sample_id = sample_id
        self.instruction = instruction
        self.context = context
        self.available_tools = available_tools


class PlanningSample:
    """Sample for planning evaluation."""

    def __init__(self, data: dict[str, Any]):
        self.sample_id = data.get("sample_id", "")
        self.instruction = data.get("instruction", "")
        self.context = data.get("context", {})
        self.available_tools = data.get("available_tools", [])
        self.ground_truth_steps = data.get("ground_truth_steps", [])
        self.tool_sequence = data.get("tool_sequence", [])


class LocalPlanningDataLoader:
    """Load planning data from local JSONL files."""

    def __init__(self, data_dir: Path):
        self.data_dir = data_dir

    def iter_split(self, task_type: str = "task_planning", split: str = "test"):
        """Iterate over samples in the specified split."""
        split_file = self.data_dir / f"{split}.jsonl"
        if not split_file.exists():
            raise FileNotFoundError(f"Split file not found: {split_file}")

        with open(split_file, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    data = json.loads(line)
                    yield PlanningSample(data)


class PlanningExperiment(BaseExperiment):
    """
    Experiment for task planning evaluation.

    Workflow:
    1. Load planning benchmark samples
    2. For each task, call planner strategy to generate plan
    3. Collect plan predictions and ground truth sequences
    4. Return ExperimentResult for evaluation
    """

    def __init__(
        self, config: PlanningConfig, data_manager: Any = None, adapter_registry: Any = None
    ):
        """
        Initialize planning experiment.

        Args:
            config: Planning configuration
            data_manager: DataManager for data loading
            adapter_registry: Registry containing planner strategies
        """
        super().__init__(config, data_manager, adapter_registry)
        self.config: PlanningConfig = config
        self._local_data_dir: Path | None = None

    def set_local_data_dir(self, data_dir: str | Path) -> None:
        """Set local data directory for loading data."""
        self._local_data_dir = Path(data_dir)

    def prepare(self):
        """Prepare experiment: load data and initialize planner."""
        # Don't call super().prepare() to avoid DataManager requirement
        verbose = getattr(self.config, "verbose", False)

        if verbose:
            print(f"\n{'=' * 60}")
            print(f"Planning Experiment: {self.experiment_id}")
            print(f"{'=' * 60}")
            print(f"Profile: {self.config.profile}")
            print(f"Split: {self.config.split}")
            print(f"Planner: {self.config.planner}")
            print(f"Max steps: {self.config.max_steps}")

        # Try local data first
        if self._local_data_dir is not None:
            self.benchmark_loader = LocalPlanningDataLoader(self._local_data_dir)
            if verbose:
                print(f"✓ Loaded local data from: {self._local_data_dir}")
        elif self.dm is not None:
            # Load data through DataManager
            try:
                agent_eval = self.dm.get_by_usage("agent_eval")
                profile_data = agent_eval.load_profile(self.config.profile)

                self.benchmark_loader = profile_data.get("benchmark")
                self.tools_loader = profile_data.get("tools")

                if verbose:
                    print("✓ Loaded benchmark data via DataManager")

            except Exception as e:
                print(f"Warning: Could not load data via DataManager: {e}")
                # Try default local path
                default_path = (
                    Path(__file__).parent.parent.parent.parent.parent.parent
                    / "data"
                    / "task_planning"
                )
                if default_path.exists():
                    self.benchmark_loader = LocalPlanningDataLoader(default_path)
                    if verbose:
                        print(f"✓ Loaded local data from default path: {default_path}")
        else:
            # Try default local path
            default_path = (
                Path(__file__).parent.parent.parent.parent.parent.parent / "data" / "task_planning"
            )
            if default_path.exists():
                self.benchmark_loader = LocalPlanningDataLoader(default_path)
                if verbose:
                    print(f"✓ Loaded local data from default path: {default_path}")

        # Initialize planner strategy
        if self.adapter_registry is not None:
            try:
                self.strategy = self.adapter_registry.get(self.config.planner)
                if verbose:
                    print(f"✓ Initialized planner: {self.config.planner}")
            except Exception as e:
                print(f"Warning: Could not load planner: {e}")
                self.strategy = None
        else:
            self.strategy = None

    def run(self) -> ExperimentResult:
        """
        Run planning experiment.

        Returns:
            ExperimentResult with plan predictions and references
        """
        verbose = getattr(self.config, "verbose", False)

        if verbose:
            print("\nRunning experiment...")

        predictions: list[dict[str, Any]] = []
        references: list[dict[str, Any]] = []
        metadata: dict[str, Any] = {"total_samples": 0, "failed_samples": 0, "avg_plan_length": 0.0}

        if self.benchmark_loader is None:
            print("Error: No data loader available")
            return self._create_result(
                predictions=predictions, references=references, metadata=metadata
            )

        try:
            samples = self.benchmark_loader.iter_split(
                task_type="task_planning", split=self.config.split
            )

            total_plan_length = 0

            for idx, sample in enumerate(samples):
                if self.config.max_samples and idx >= self.config.max_samples:
                    break

                metadata["total_samples"] += 1

                try:
                    # Create planning task
                    task = PlanningTask(
                        sample_id=sample.sample_id,
                        instruction=sample.instruction,
                        context=sample.context if hasattr(sample, "context") else {},
                        available_tools=(
                            sample.available_tools if hasattr(sample, "available_tools") else []
                        ),
                    )

                    # Get prediction from strategy
                    if self.strategy is not None:
                        plan = self.strategy.plan(task)

                        pred_dict = {
                            "sample_id": sample.sample_id,
                            "plan_steps": [
                                {
                                    "step_id": step.step_id,
                                    "description": step.description,
                                    "tool_id": step.tool_id,
                                    "confidence": step.confidence,
                                }
                                for step in plan.steps
                            ],
                            "tool_sequence": plan.tool_sequence,
                        }
                        total_plan_length += len(plan.steps)
                    else:
                        pred_dict = {
                            "sample_id": sample.sample_id,
                            "plan_steps": [],
                            "tool_sequence": [],
                        }

                    predictions.append(pred_dict)

                    # Get ground truth from sample data
                    ref_dict = {
                        "sample_id": sample.sample_id,
                        "plan_steps": sample.ground_truth_steps,
                        "tool_sequence": sample.tool_sequence,
                    }
                    references.append(ref_dict)

                    if verbose and (idx + 1) % 10 == 0:
                        print(f" Processed {idx + 1} samples...")

                except Exception as e:
                    metadata["failed_samples"] += 1
                    if verbose:
                        print(f" Error processing sample {idx}: {e}")
                    continue

            # Calculate average plan length
            if metadata["total_samples"] > 0:
                metadata["avg_plan_length"] = total_plan_length / metadata["total_samples"]

        except Exception as e:
            print(f"Error iterating samples: {e}")

        if verbose:
            print("\nCompleted:")
            print(f" Total samples: {metadata['total_samples']}")
            print(f" Failed samples: {metadata['failed_samples']}")
            print(f" Avg plan length: {metadata['avg_plan_length']:.1f}")

        return self._create_result(
            predictions=predictions, references=references, metadata=metadata
        )
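
A note on the data this module expects: LocalPlanningDataLoader reads `<split>.jsonl` from the configured directory, one JSON object per line, and PlanningSample only looks at six keys. A hedged sketch of a compatible record and of driving the loader directly (the field values and the `data/task_planning` path are illustrative; only the key names come from the code above):

# Hypothetical record for data/task_planning/test.jsonl (one object per line):
# {"sample_id": "plan_0001",
#  "instruction": "Find a nearby restaurant and add a reservation to my calendar",
#  "context": {"locale": "en-US"},
#  "available_tools": ["restaurant_search", "calendar_create_event"],
#  "ground_truth_steps": [{"step_id": 1, "description": "Search restaurants", "tool_id": "restaurant_search"}],
#  "tool_sequence": ["restaurant_search", "calendar_create_event"]}

from pathlib import Path

from sage.benchmark.benchmark_agent.experiments.planning_exp import LocalPlanningDataLoader

loader = LocalPlanningDataLoader(Path("data/task_planning"))
for sample in loader.iter_split(split="test"):
    print(sample.sample_id, sample.tool_sequence)

PlanningExperiment wires this loader in automatically when set_local_data_dir() is called before prepare(); without a registered planner strategy, run() still iterates the split and emits empty plan predictions alongside the ground-truth references.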
sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py
@@ -0,0 +1,198 @@
"""
Timing Detection Experiment

Experiment runner for evaluating agent timing judgment capabilities.
Tests ability to decide when to invoke tools versus answering directly.
"""

from typing import Any, Optional

from sage.benchmark.benchmark_agent.experiments.base_experiment import (
    BaseExperiment,
    ExperimentResult,
    TimingDetectionConfig,
)


class TimingMessage:
    """Message for timing judgment."""

    def __init__(
        self,
        sample_id: str,
        message: str,
        context: dict[str, Any],
        direct_answer: Optional[str] = None,
    ):
        self.sample_id = sample_id
        self.message = message
        self.context = context
        self.direct_answer = direct_answer


class TimingDetectionExperiment(BaseExperiment):
    """
    Experiment for timing judgment evaluation.

    Workflow:
    1. Load timing judgment benchmark samples
    2. For each message, call detector to decide whether to call tool
    3. Collect decisions and ground truth
    4. Return ExperimentResult for evaluation
    """

    def __init__(
        self, config: TimingDetectionConfig, data_manager: Any = None, adapter_registry: Any = None
    ):
        """
        Initialize timing detection experiment.

        Args:
            config: Timing detection configuration
            data_manager: DataManager for data loading
            adapter_registry: Registry containing detector strategies
        """
        super().__init__(config, data_manager, adapter_registry)
        self.config: TimingDetectionConfig = config

    def prepare(self):
        """Prepare experiment: load data and initialize detector."""
        super().prepare()

        verbose = getattr(self.config, "verbose", False)
        if verbose:
            print(f"\n{'=' * 60}")
            print(f"Timing Detection Experiment: {self.experiment_id}")
            print(f"{'=' * 60}")
            print(f"Profile: {self.config.profile}")
            print(f"Split: {self.config.split}")
            print(f"Detector: {self.config.detector}")
            print(f"Threshold: {self.config.threshold}")

        # Load data through DataManager
        try:
            agent_eval = self.dm.get_by_usage("agent_eval")
            profile_data = agent_eval.load_profile(self.config.profile)

            self.benchmark_loader = profile_data.get("benchmark")

            if verbose:
                print("✓ Loaded benchmark data")

        except Exception as e:
            print(f"Warning: Could not load data: {e}")

        # Initialize detector strategy
        if self.adapter_registry is not None:
            try:
                self.strategy = self.adapter_registry.get(self.config.detector)
                if verbose:
                    print(f"✓ Initialized detector: {self.config.detector}")
            except Exception as e:
                print(f"Warning: Could not load detector: {e}")
                self.strategy = None
        else:
            self.strategy = None

    def run(self) -> ExperimentResult:
        """
        Run timing detection experiment.

        Returns:
            ExperimentResult with timing decisions and references
        """
        verbose = getattr(self.config, "verbose", False)
        if verbose:
            print("\nRunning experiment...")

        predictions = []
        references = []
        metadata = {
            "total_samples": 0,
            "failed_samples": 0,
            "positive_predictions": 0,
            "negative_predictions": 0,
        }

        try:
            samples = self.benchmark_loader.iter_split(
                task_type="timing_judgment", split=self.config.split
            )

            for idx, sample in enumerate(samples):
                if self.config.max_samples and idx >= self.config.max_samples:
                    break

                metadata["total_samples"] += 1

                try:
                    # Create timing message
                    message = TimingMessage(
                        sample_id=sample.sample_id,
                        message=sample.instruction,
                        context=sample.context if hasattr(sample, "context") else {},
                        direct_answer=(
                            sample.direct_answer if hasattr(sample, "direct_answer") else None
                        ),
                    )

                    # Get prediction from strategy
                    if self.strategy is not None:
                        decision = self.strategy.decide(message)

                        pred_dict = {
                            "sample_id": sample.sample_id,
                            "should_call_tool": decision.should_call_tool,
                            "confidence": decision.confidence,
                            "reasoning": decision.reasoning,
                        }

                        if decision.should_call_tool:
                            metadata["positive_predictions"] += 1
                        else:
                            metadata["negative_predictions"] += 1
                    else:
                        pred_dict = {
                            "sample_id": sample.sample_id,
                            "should_call_tool": False,
                            "confidence": 0.5,
                            "reasoning": None,
                        }
                        metadata["negative_predictions"] += 1

                    predictions.append(pred_dict)

                    # Get ground truth
                    gt = sample.get_typed_ground_truth()
                    ref_dict = {
                        "sample_id": sample.sample_id,
                        "should_call_tool": gt.should_call_tool,
                        "reasoning_chain": (
                            gt.reasoning_chain if hasattr(gt, "reasoning_chain") else None
                        ),
                        "direct_answer": gt.direct_answer if hasattr(gt, "direct_answer") else None,
                    }
                    references.append(ref_dict)

                    if verbose and (idx + 1) % 10 == 0:
                        print(f" Processed {idx + 1} samples...")

                except Exception as e:
                    metadata["failed_samples"] += 1
                    if verbose:
                        print(f" Error processing sample {idx}: {e}")
                    continue

        except Exception as e:
            print(f"Error iterating samples: {e}")

        if verbose:
            print("\nCompleted:")
            print(f" Total samples: {metadata['total_samples']}")
            print(f" Failed samples: {metadata['failed_samples']}")
            print(f" Positive predictions: {metadata['positive_predictions']}")
            print(f" Negative predictions: {metadata['negative_predictions']}")

        return self._create_result(
            predictions=predictions, references=references, metadata=metadata
        )