isage-benchmark-agent 0.1.0.1__cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Evaluation module for Agent Capability Benchmark.
|
|
3
|
+
|
|
4
|
+
This module provides metrics, analyzers, and report builders for evaluating
|
|
5
|
+
agent performance across three capabilities: tool selection, task planning,
|
|
6
|
+
and timing judgment.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Dict, List, Optional, Protocol, Sequence
|
|
11
|
+
|
|
12
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
13
|
+
|
|
14
|
+
# Public API of the evaluation package: the pydantic data models, the
# structural protocols, the compute_metrics helper, and MetricRegistry
# (re-exported from the metrics submodule imported further below).
__all__ = [
    "MetricOutput",
    "EvaluationReport",
    "Metric",
    "Analyzer",
    "ReportBuilder",
    "compute_metrics",
    "MetricRegistry",
]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MetricOutput(BaseModel):
    """Output from a metric computation.

    Attributes:
        value: Scalar score produced by the metric.
        details: Optional metric-specific supplementary data; empty by default.
    """

    # Scalar metric score.
    value: float
    # default_factory gives every instance its own dict (no shared mutable default).
    details: dict[str, Any] = Field(default_factory=dict)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class EvaluationReport(BaseModel):
    """Complete evaluation report with metrics, breakdowns, and artifacts.

    Attributes:
        task: Task the report covers (presumably one of 'tool_selection',
            'planning', 'timing_detection' -- confirm against callers).
        experiment_id: Identifier of the experiment that produced the report.
        metrics: Mapping of metric name to scalar value.
        breakdowns: Free-form detailed analysis data; empty by default.
        artifacts: Mapping of artifact name to file path; empty by default.
        timestamp: Report timestamp as a string (no format is enforced here).
    """

    task: str
    experiment_id: str
    # NOTE(review): compute_metrics may also return "<name>_error" (str) and
    # "<name>_details" (dict) entries; those fail float validation if passed
    # here unfiltered -- confirm callers strip them.
    metrics: dict[str, float]
    breakdowns: dict[str, Any] = Field(default_factory=dict)
    artifacts: dict[str, Path] = Field(default_factory=dict)
    # Required field; pydantic (unlike dataclasses) allows it after defaulted ones.
    timestamp: str

    # NOTE(review): pydantic validates pathlib.Path natively, so this flag may
    # be unnecessary for ``artifacts`` -- confirm before removing.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Metric(Protocol):
    """Structural interface satisfied by every metric implementation.

    Any object exposing a ``name`` attribute and a matching ``compute``
    method conforms; explicit subclassing is not required.
    """

    # Identifier of the metric.
    name: str

    def compute(
        self,
        predictions: Sequence[Any],
        references: Sequence[Any],
    ) -> MetricOutput:
        """Score ``predictions`` against ``references``.

        Args:
            predictions: Model predictions.
            references: Ground truth references.

        Returns:
            A MetricOutput holding the scalar value and optional details.
        """
        ...
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Analyzer(Protocol):
    """Structural interface satisfied by every analyzer implementation.

    Conforming objects expose a ``name`` attribute and an ``analyze`` method
    that turns predictions/references into a breakdown dictionary.
    """

    # Identifier of the analyzer.
    name: str

    def analyze(
        self,
        predictions: Sequence[Any],
        references: Sequence[Any],
        metadata: dict[str, Any],
    ) -> dict[str, Any]:
        """Produce breakdowns for ``predictions`` versus ``references``.

        Args:
            predictions: Model predictions.
            references: Ground truth references.
            metadata: Extra context supplied by the experiment.

        Returns:
            A dictionary of analysis results.
        """
        ...
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ReportBuilder(Protocol):
    """Structural interface satisfied by every report builder implementation.

    A conforming object renders an EvaluationReport and writes it to disk.
    """

    def build(
        self,
        report: EvaluationReport,
        output_path: Path,
    ) -> Path:
        """Render ``report`` and save it at ``output_path``.

        Args:
            report: The EvaluationReport to format.
            output_path: Destination path for the rendered report.

        Returns:
            Path of the saved report file.
        """
        ...
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Import metric registry after defining base classes
|
|
104
|
+
from sage.benchmark.benchmark_agent.evaluation.metrics import MetricRegistry
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def compute_metrics(
|
|
108
|
+
task: str,
|
|
109
|
+
predictions: list[dict[str, Any]],
|
|
110
|
+
references: list[dict[str, Any]],
|
|
111
|
+
metrics: list[str],
|
|
112
|
+
k: int = 5,
|
|
113
|
+
) -> dict[str, float]:
|
|
114
|
+
"""
|
|
115
|
+
Compute evaluation metrics for experiment results.
|
|
116
|
+
|
|
117
|
+
Args:
|
|
118
|
+
task: Task type ('tool_selection', 'planning', 'timing_detection')
|
|
119
|
+
predictions: List of prediction dictionaries
|
|
120
|
+
references: List of reference dictionaries
|
|
121
|
+
metrics: List of metric names to compute
|
|
122
|
+
k: Top-k parameter for ranking metrics
|
|
123
|
+
|
|
124
|
+
Returns:
|
|
125
|
+
Dictionary mapping metric names to values
|
|
126
|
+
"""
|
|
127
|
+
results = {}
|
|
128
|
+
|
|
129
|
+
if task == "tool_selection":
|
|
130
|
+
# Extract tool lists from predictions and references
|
|
131
|
+
pred_tools = []
|
|
132
|
+
ref_tools = []
|
|
133
|
+
|
|
134
|
+
for pred, ref in zip(predictions, references):
|
|
135
|
+
# Get predicted tool IDs
|
|
136
|
+
if "predicted_tools" in pred:
|
|
137
|
+
tools = pred["predicted_tools"]
|
|
138
|
+
if tools and isinstance(tools[0], dict):
|
|
139
|
+
pred_tools.append([t["tool_id"] for t in tools])
|
|
140
|
+
else:
|
|
141
|
+
pred_tools.append(tools if tools else [])
|
|
142
|
+
else:
|
|
143
|
+
pred_tools.append([])
|
|
144
|
+
|
|
145
|
+
# Get reference tool IDs
|
|
146
|
+
if "ground_truth_tools" in ref:
|
|
147
|
+
ref_tools.append(ref["ground_truth_tools"])
|
|
148
|
+
elif "top_k" in ref:
|
|
149
|
+
ref_tools.append(ref["top_k"])
|
|
150
|
+
else:
|
|
151
|
+
ref_tools.append([])
|
|
152
|
+
|
|
153
|
+
# Compute each metric
|
|
154
|
+
for metric_name in metrics:
|
|
155
|
+
try:
|
|
156
|
+
if metric_name in ("top_k_accuracy", "recall_at_k", "precision_at_k"):
|
|
157
|
+
metric = MetricRegistry.get(metric_name, k=k)
|
|
158
|
+
elif metric_name == "mrr":
|
|
159
|
+
metric = MetricRegistry.get("mrr")
|
|
160
|
+
else:
|
|
161
|
+
continue
|
|
162
|
+
|
|
163
|
+
output = metric.compute(pred_tools, ref_tools)
|
|
164
|
+
results[metric_name] = output.value
|
|
165
|
+
except Exception as e:
|
|
166
|
+
results[metric_name] = 0.0
|
|
167
|
+
results[f"{metric_name}_error"] = str(e)
|
|
168
|
+
|
|
169
|
+
elif task == "timing_detection":
|
|
170
|
+
# Extract boolean decisions
|
|
171
|
+
pred_decisions = []
|
|
172
|
+
ref_decisions = []
|
|
173
|
+
|
|
174
|
+
for pred, ref in zip(predictions, references):
|
|
175
|
+
pred_decisions.append(pred.get("should_call_tool", False))
|
|
176
|
+
ref_decisions.append(ref.get("should_call_tool", False))
|
|
177
|
+
|
|
178
|
+
# Metric name mapping for timing detection
|
|
179
|
+
timing_metric_map = {
|
|
180
|
+
"accuracy": "timing_accuracy",
|
|
181
|
+
"precision": "timing_precision",
|
|
182
|
+
"recall": "timing_recall",
|
|
183
|
+
"f1": "timing_f1",
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
for metric_name in metrics:
|
|
187
|
+
try:
|
|
188
|
+
# Map simple names to full metric names
|
|
189
|
+
registry_name = timing_metric_map.get(metric_name, metric_name)
|
|
190
|
+
metric = MetricRegistry.get(registry_name)
|
|
191
|
+
output = metric.compute(pred_decisions, ref_decisions)
|
|
192
|
+
results[metric_name] = output.value
|
|
193
|
+
# Include details if available
|
|
194
|
+
if hasattr(output, "details") and output.details:
|
|
195
|
+
results[f"{metric_name}_details"] = output.details
|
|
196
|
+
except Exception as e:
|
|
197
|
+
results[metric_name] = 0.0
|
|
198
|
+
results[f"{metric_name}_error"] = str(e)
|
|
199
|
+
|
|
200
|
+
elif task == "planning":
|
|
201
|
+
# Extract tool sequences
|
|
202
|
+
pred_sequences = []
|
|
203
|
+
ref_sequences = []
|
|
204
|
+
|
|
205
|
+
for pred, ref in zip(predictions, references):
|
|
206
|
+
pred_sequences.append(pred.get("tool_sequence", []))
|
|
207
|
+
ref_sequences.append(ref.get("tool_sequence", []))
|
|
208
|
+
|
|
209
|
+
for metric_name in metrics:
|
|
210
|
+
try:
|
|
211
|
+
metric = MetricRegistry.get(metric_name)
|
|
212
|
+
output = metric.compute(pred_sequences, ref_sequences)
|
|
213
|
+
results[metric_name] = output.value
|
|
214
|
+
except Exception:
|
|
215
|
+
results[metric_name] = 0.0
|
|
216
|
+
|
|
217
|
+
return results
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Analyzers package initialization."""
|
|
2
|
+
|
|
3
|
+
from .planning_analyzer import PlanningAnalyzer
|
|
4
|
+
from .timing_analyzer import TimingAnalyzer
|
|
5
|
+
from .tool_selection_analyzer import ToolSelectionAnalyzer
|
|
6
|
+
|
|
7
|
+
# Analyzer classes exported by ``from ...analyzers import *``.
__all__ = [
    "ToolSelectionAnalyzer",
    "PlanningAnalyzer",
    "TimingAnalyzer",
]
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Planning analyzer for step-level alignment analysis."""
|
|
2
|
+
|
|
3
|
+
from collections import Counter, defaultdict
|
|
4
|
+
from typing import Any, Sequence
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PlanningAnalyzer:
    """
    Analyzer for task planning predictions.

    Provides breakdowns by:
    - Step-level correctness
    - Tool sequence alignment
    - Failure patterns
    """

    name = "planning"

    def analyze(
        self,
        predictions: Sequence[list[str]],
        references: Sequence[list[str]],
        metadata: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Analyze planning predictions.

        Predictions and references are paired positionally (``zip``). Each
        sequence is normalized to a list, so tuple and list inputs compare
        equal.

        Args:
            predictions: List of predicted tool sequences
            references: List of reference tool sequences
            metadata: Additional context (not used by this analyzer)

        Returns:
            Dictionary with ``exact_match_rate``, ``prefix_match_rate``,
            ``step_correctness``, ``length_analysis``, ``failure_modes``,
            ``tool_usage`` and ``total_samples``. Rates are divided by
            ``len(predictions)``.
        """
        step_correctness: list[float] = []
        length_diffs: list[int] = []
        failure_modes: dict[str, int] = defaultdict(int)

        exact_matches = 0
        prefix_matches = 0

        for raw_pred, raw_ref in zip(predictions, references):
            # Normalize so a tuple prediction can equal a list reference.
            pred, ref = list(raw_pred), list(raw_ref)
            length_diffs.append(len(pred) - len(ref))

            if pred == ref:
                exact_matches += 1
                prefix_matches += 1  # an exact match is also a prefix match
                step_correctness.append(1.0)
                failure_modes["perfect"] += 1
                continue

            # One sequence is a (non-empty) prefix of the other.
            min_len = min(len(pred), len(ref))
            if min_len > 0 and pred[:min_len] == ref[:min_len]:
                prefix_matches += 1

            # Position-wise matches, normalized by the reference length.
            correct_steps = sum(1 for p, r in zip(pred, ref) if p == r)
            step_correctness.append(correct_steps / len(ref) if len(ref) > 0 else 0.0)

            failure_modes[self._classify_failure(pred, ref)] += 1

        pred_lengths = [len(p) for p in predictions]
        ref_lengths = [len(r) for r in references]

        tool_usage_pred: Counter[str] = Counter()
        tool_usage_ref: Counter[str] = Counter()
        for pred_seq in predictions:
            tool_usage_pred.update(pred_seq)
        for ref_seq in references:
            tool_usage_ref.update(ref_seq)

        return {
            "exact_match_rate": exact_matches / len(predictions) if predictions else 0.0,
            "prefix_match_rate": prefix_matches / len(predictions) if predictions else 0.0,
            "step_correctness": {
                "mean": sum(step_correctness) / len(step_correctness) if step_correctness else 0.0,
                "min": min(step_correctness) if step_correctness else 0.0,
                "max": max(step_correctness) if step_correctness else 0.0,
                "distribution": step_correctness,
            },
            "length_analysis": {
                "pred_avg": sum(pred_lengths) / len(pred_lengths) if pred_lengths else 0.0,
                "ref_avg": sum(ref_lengths) / len(ref_lengths) if ref_lengths else 0.0,
                "length_diff_mean": sum(length_diffs) / len(length_diffs) if length_diffs else 0.0,
                "length_diff_distribution": length_diffs,
            },
            "failure_modes": dict(failure_modes),
            "tool_usage": {
                "predicted_most_common": tool_usage_pred.most_common(10),
                "reference_most_common": tool_usage_ref.most_common(10),
            },
            "total_samples": len(predictions),
        }

    @staticmethod
    def _classify_failure(pred: list[str], ref: list[str]) -> str:
        """Label why a predicted plan that is not identical to ``ref`` is wrong."""
        if len(pred) == 0:
            return "empty_plan"
        if len(pred) < len(ref):
            return "too_short"
        if len(pred) > len(ref):
            return "too_long"
        # Same length: it is only a reordering if both use exactly the same
        # multiset of tools. Comparing sets would mislabel [a, a, b] vs
        # [a, b, b] as "wrong_order".
        if Counter(pred) == Counter(ref):
            return "wrong_order"
        return "wrong_tools"
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Timing analyzer for confusion matrix and threshold analysis."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TimingAnalyzer:
    """
    Analyzer for timing judgment predictions.

    Provides breakdowns by:
    - Confusion matrix
    - Confidence distribution
    - Threshold sensitivity (if confidence scores available)
    """

    name = "timing"

    def analyze(
        self, predictions: Sequence[bool], references: Sequence[bool], metadata: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Analyze timing predictions.

        Args:
            predictions: List of predicted decisions (True = call tool)
            references: List of reference decisions, same length as predictions
            metadata: Additional context. A non-None ``"confidences"`` entry
                (assumed to hold one score per prediction) enables the
                confidence statistics and threshold sweep.

        Returns:
            Dictionary with ``confusion_matrix``, ``rates``,
            ``class_distribution``, ``confidence_analysis`` (empty when no
            confidences are supplied) and ``total_samples``.

        Raises:
            ValueError: If predictions and references differ in length.
        """
        # Without this check numpy silently broadcasts a length-1 side.
        if len(predictions) != len(references):
            raise ValueError(
                "predictions and references must have the same length "
                f"(got {len(predictions)} and {len(references)})"
            )

        preds = np.array(predictions, dtype=bool)
        refs = np.array(references, dtype=bool)
        total = len(predictions)

        # Confusion matrix
        true_positives = int(np.sum(preds & refs))
        false_positives = int(np.sum(preds & ~refs))
        false_negatives = int(np.sum(~preds & refs))
        true_negatives = int(np.sum(~preds & ~refs))

        confusion_matrix = {
            "true_positives": true_positives,
            "false_positives": false_positives,
            "false_negatives": false_negatives,
            "true_negatives": true_negatives,
        }

        # Derived rates (0.0 when the denominator is empty)
        positive_rate = self._safe_ratio(true_positives + false_positives, total)
        true_positive_rate = self._safe_ratio(true_positives, true_positives + false_negatives)
        false_positive_rate = self._safe_ratio(false_positives, false_positives + true_negatives)

        # Class distribution. np.mean of an empty array is NaN (with a
        # RuntimeWarning), so report 0.0 for empty input like the rates above.
        class_distribution = {
            "reference_positive_ratio": float(np.mean(refs)) if total > 0 else 0.0,
            "predicted_positive_ratio": float(np.mean(preds)) if total > 0 else 0.0,
            "reference_positive_count": int(np.sum(refs)),
            "predicted_positive_count": int(np.sum(preds)),
        }

        confidence_analysis: dict[str, Any] = {}
        raw_confidences = metadata.get("confidences")
        if raw_confidences is not None:
            confidence_analysis = self._analyze_confidences(
                np.asarray(raw_confidences, dtype=float), preds, refs
            )

        return {
            "confusion_matrix": confusion_matrix,
            "rates": {
                "true_positive_rate": true_positive_rate,
                "false_positive_rate": false_positive_rate,
                "predicted_positive_rate": positive_rate,
            },
            "class_distribution": class_distribution,
            "confidence_analysis": confidence_analysis,
            "total_samples": total,
        }

    def _analyze_confidences(
        self, confidences: np.ndarray, preds: np.ndarray, refs: np.ndarray
    ) -> dict[str, Any]:
        """Summarize confidence scores and sweep decision thresholds 0.1..0.9."""
        bins = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
        correct_mask = preds == refs

        analysis: dict[str, Any] = {
            "mean_confidence_correct": (
                float(np.mean(confidences[correct_mask])) if np.any(correct_mask) else 0.0
            ),
            "mean_confidence_incorrect": (
                float(np.mean(confidences[~correct_mask])) if np.any(~correct_mask) else 0.0
            ),
            "mean_confidence_overall": (
                float(np.mean(confidences)) if confidences.size > 0 else 0.0
            ),
            "confidence_distribution": {
                "bins": list(bins),
                "counts": np.histogram(confidences, bins=bins)[0].tolist(),
            },
        }

        # Built from integers so each threshold is the exact decimal value.
        # np.linspace(0.1, 0.9, 9) yields e.g. 0.30000000000000004, which put
        # a confidence of exactly 0.3 below the "0.3" threshold.
        thresholds = np.arange(1, 10) / 10

        threshold_metrics = []
        for threshold in thresholds:
            thresh_preds = confidences >= threshold
            tp = int(np.sum(thresh_preds & refs))
            fp = int(np.sum(thresh_preds & ~refs))
            fn = int(np.sum(~thresh_preds & refs))

            precision = self._safe_ratio(tp, tp + fp)
            recall = self._safe_ratio(tp, tp + fn)
            f1 = (
                2 * (precision * recall) / (precision + recall)
                if (precision + recall) > 0
                else 0.0
            )
            threshold_metrics.append(
                {
                    "threshold": float(threshold),
                    "precision": precision,
                    "recall": recall,
                    "f1": f1,
                }
            )

        analysis["threshold_sensitivity"] = threshold_metrics
        return analysis

    @staticmethod
    def _safe_ratio(numerator: int, denominator: int) -> float:
        """Return numerator / denominator, or 0.0 when the denominator is not positive."""
        return numerator / denominator if denominator > 0 else 0.0
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""Tool selection analyzer for detailed error analysis."""
|
|
2
|
+
|
|
3
|
+
from collections import Counter, defaultdict
|
|
4
|
+
from typing import Any, Optional, Sequence
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ToolSelectionAnalyzer:
    """
    Analyzer for tool selection predictions.

    Provides breakdowns by:
    - Category coverage
    - Error patterns (wrong tool categories)
    - Tool popularity in predictions vs references
    """

    name = "tool_selection"

    def __init__(self, tools_metadata: Optional[dict[str, Any]] = None):
        """
        Initialize analyzer.

        Args:
            tools_metadata: Optional metadata about tools (categories, etc.)
        """
        # NOTE(review): stored but not read by any method of this class;
        # _get_category derives categories from the tool ID string instead.
        self.tools_metadata = tools_metadata or {}

    def analyze(
        self,
        predictions: Sequence[list[str]],
        references: Sequence[list[str]],
        metadata: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Analyze tool selection predictions.

        Predictions and references are paired positionally (``zip``).
        Per-sample error patterns:
        - ``all_correct``: every distinct predicted tool is in the reference
          (reference tools may still be missed); also used when both
          prediction and reference are empty.
        - ``partial_correct``: some, but not all, predicted tools are correct.
        - ``complete_miss``: no predicted tool is in the reference.

        Args:
            predictions: List of predicted tool ID lists
            references: List of reference tool ID lists
            metadata: Additional context (not used by this analyzer)

        Returns:
            Dictionary with ``error_patterns``, ``tool_coverage``,
            ``tool_frequency``, ``category_performance`` and
            ``selection_accuracy``. ``selection_accuracy`` counts distinct
            tools per sample in both numerator and denominator.
        """
        # Tool frequency analysis (raw counts, duplicates included)
        pred_tools: Counter[str] = Counter()
        ref_tools: Counter[str] = Counter()
        for pred_list in predictions:
            pred_tools.update(pred_list)
        for ref_list in references:
            ref_tools.update(ref_list)

        errors_by_type: dict[str, int] = defaultdict(int)
        correct_selections = 0
        total_predictions = 0

        category_hits: dict[str, int] = defaultdict(int)
        category_misses: dict[str, int] = defaultdict(int)

        for pred, ref in zip(predictions, references):
            ref_set = set(ref)
            pred_set = set(pred)

            correct = pred_set & ref_set
            incorrect = pred_set - ref_set

            # Both counts are over distinct tools; previously duplicates in
            # ``pred`` inflated only the denominator and deflated accuracy.
            correct_selections += len(correct)
            total_predictions += len(pred_set)

            if not pred_set and not ref_set:
                # Nothing expected and nothing predicted: the sample is correct.
                errors_by_type["all_correct"] += 1
            elif len(correct) == 0:
                errors_by_type["complete_miss"] += 1
            elif len(incorrect) > 0:
                errors_by_type["partial_correct"] += 1
            else:
                errors_by_type["all_correct"] += 1

            # Category-level hits/misses for each reference tool
            for tool_id in ref:
                category = self._get_category(tool_id)
                if tool_id in pred_set:
                    category_hits[category] += 1
                else:
                    category_misses[category] += 1

        pred_tool_set = set(pred_tools.keys())
        ref_tool_set = set(ref_tools.keys())

        return {
            "error_patterns": dict(errors_by_type),
            "tool_coverage": {
                "predicted_tools": len(pred_tool_set),
                "reference_tools": len(ref_tool_set),
                "overlap": len(pred_tool_set & ref_tool_set),
                "predicted_only": len(pred_tool_set - ref_tool_set),
                "missed": len(ref_tool_set - pred_tool_set),
            },
            "tool_frequency": {
                "top_predicted": pred_tools.most_common(10),
                "top_reference": ref_tools.most_common(10),
            },
            "category_performance": {
                "hits_by_category": dict(category_hits),
                "misses_by_category": dict(category_misses),
            },
            "selection_accuracy": {
                "correct_selections": correct_selections,
                "total_predictions": total_predictions,
                "accuracy": (
                    correct_selections / total_predictions if total_predictions > 0 else 0.0
                ),
            },
        }

    def _get_category(self, tool_id: str) -> str:
        """Return the first two underscore-separated parts of ``tool_id``, or "unknown"."""
        # Tool ID format: {domain}_{category}_{number}
        parts = tool_id.split("_")
        if len(parts) >= 2:
            return f"{parts[0]}_{parts[1]}"
        return "unknown"
|