latch-eval-tools 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- latch_eval_tools/__init__.py +64 -0
- latch_eval_tools/answer_extraction.py +35 -0
- latch_eval_tools/cli/__init__.py +0 -0
- latch_eval_tools/cli/eval_lint.py +185 -0
- latch_eval_tools/eval_server.py +570 -0
- latch_eval_tools/faas_utils.py +13 -0
- latch_eval_tools/graders/__init__.py +40 -0
- latch_eval_tools/graders/base.py +29 -0
- latch_eval_tools/graders/distribution.py +102 -0
- latch_eval_tools/graders/label_set.py +75 -0
- latch_eval_tools/graders/marker_gene.py +317 -0
- latch_eval_tools/graders/multiple_choice.py +38 -0
- latch_eval_tools/graders/numeric.py +137 -0
- latch_eval_tools/graders/spatial.py +93 -0
- latch_eval_tools/harness/__init__.py +27 -0
- latch_eval_tools/harness/claudecode.py +212 -0
- latch_eval_tools/harness/minisweagent.py +265 -0
- latch_eval_tools/harness/plotsagent.py +156 -0
- latch_eval_tools/harness/runner.py +191 -0
- latch_eval_tools/harness/utils.py +191 -0
- latch_eval_tools/headless_eval_server.py +727 -0
- latch_eval_tools/linter/__init__.py +25 -0
- latch_eval_tools/linter/explanations.py +331 -0
- latch_eval_tools/linter/runner.py +146 -0
- latch_eval_tools/linter/schema.py +126 -0
- latch_eval_tools/linter/validators.py +595 -0
- latch_eval_tools/types.py +30 -0
- latch_eval_tools/wrapper_entrypoint.py +316 -0
- latch_eval_tools-0.1.0.dist-info/METADATA +118 -0
- latch_eval_tools-0.1.0.dist-info/RECORD +33 -0
- latch_eval_tools-0.1.0.dist-info/WHEEL +4 -0
- latch_eval_tools-0.1.0.dist-info/entry_points.txt +2 -0
- latch_eval_tools-0.1.0.dist-info/licenses/LICENSE +1 -0
latch_eval_tools/graders/distribution.py
@@ -0,0 +1,102 @@
+from .base import BinaryGrader, GraderResult
+
+
+class DistributionComparisonGrader(BinaryGrader):
+    def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+        ground_truth = config.get("ground_truth", {})
+        tolerances = config.get("tolerances", {})
+
+        gt_total_cells = ground_truth.get("total_cells")
+        gt_distribution = ground_truth.get("cell_type_distribution", {})
+
+        total_cells_tolerance = tolerances.get("total_cells", {})
+        pct_tolerance_config = tolerances.get("cell_type_percentages", {})
+        pct_tolerance = pct_tolerance_config.get("value", 3.0)
+
+        if "cell_type_distribution" not in agent_answer:
+            return GraderResult(
+                passed=False,
+                metrics={},
+                reasoning="Agent answer missing required field: cell_type_distribution",
+                agent_answer=agent_answer
+            )
+
+        agent_total_cells = agent_answer.get("total_cells")
+        agent_distribution = agent_answer["cell_type_distribution"]
+
+        metrics = {}
+        all_pass = True
+        failures = []
+
+        if gt_total_cells is not None and agent_total_cells is not None:
+            total_cells_tol_value = total_cells_tolerance.get("value", 0)
+            total_cells_diff = abs(agent_total_cells - gt_total_cells)
+            total_cells_pass = total_cells_diff <= total_cells_tol_value
+
+            metrics["total_cells_actual"] = agent_total_cells
+            metrics["total_cells_expected"] = gt_total_cells
+            metrics["total_cells_diff"] = total_cells_diff
+            metrics["total_cells_pass"] = total_cells_pass
+
+            if not total_cells_pass:
+                all_pass = False
+                failures.append(f"total_cells: {agent_total_cells} vs {gt_total_cells} (diff: {total_cells_diff})")
+
+        distribution_failures = []
+        for cell_type, expected_pct in gt_distribution.items():
+            if cell_type not in agent_distribution:
+                all_pass = False
+                failures.append(f"Missing cell type: {cell_type}")
+                distribution_failures.append(cell_type)
+                metrics[f"{cell_type}_actual"] = None
+                metrics[f"{cell_type}_expected"] = expected_pct
+                metrics[f"{cell_type}_diff"] = None
+                metrics[f"{cell_type}_pass"] = False
+                continue
+
+            actual_pct = agent_distribution[cell_type]
+            diff = abs(actual_pct - expected_pct)
+            within_tolerance = diff <= pct_tolerance
+
+            metrics[f"{cell_type}_actual"] = actual_pct
+            metrics[f"{cell_type}_expected"] = expected_pct
+            metrics[f"{cell_type}_diff"] = diff
+            metrics[f"{cell_type}_pass"] = within_tolerance
+
+            if not within_tolerance:
+                all_pass = False
+                failures.append(f"{cell_type}: {actual_pct:.2f}% vs {expected_pct:.2f}% (diff: {diff:.2f}%)")
+                distribution_failures.append(cell_type)
+
+        extra_types = set(agent_distribution.keys()) - set(gt_distribution.keys())
+        if extra_types:
+            metrics["extra_cell_types"] = sorted(list(extra_types))
+
+        lines = [
+            f"Distribution Comparison: {'PASS' if all_pass else 'FAIL'}",
+            "",
+            f"Cell type percentages (tolerance: +/-{pct_tolerance}%):"
+        ]
+
+        for cell_type in sorted(gt_distribution.keys()):
+            expected = gt_distribution[cell_type]
+            if cell_type not in agent_distribution:
+                lines.append(f"  x {cell_type}: MISSING vs {expected:.2f}%")
+            else:
+                actual = agent_distribution[cell_type]
+                diff = abs(actual - expected)
+                within_tol = diff <= pct_tolerance
+                check = "+" if within_tol else "x"
+                lines.append(f"  {check} {cell_type}: {actual:.2f}% vs {expected:.2f}% (diff: {diff:.2f}%)")
+
+        if failures:
+            lines.extend(["", "Failures:"])
+            for failure in failures:
+                lines.append(f"  - {failure}")
+
+        return GraderResult(
+            passed=all_pass,
+            metrics=metrics,
+            reasoning="\n".join(lines),
+            agent_answer=agent_answer
+        )
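
A minimal usage sketch for the grader above (editor's addition, not part of the package). The config and answer shapes follow the keys the code reads; it assumes BinaryGrader subclasses take no constructor arguments and that GraderResult exposes its constructor fields as attributes, which base.py (not shown in this diff) would need to confirm.

# Editor's sketch, not package code; assumptions noted above.
from latch_eval_tools.graders.distribution import DistributionComparisonGrader

config = {
    "ground_truth": {
        "total_cells": 5000,
        "cell_type_distribution": {"T cell": 45.0, "B cell": 30.0, "NK cell": 25.0},
    },
    "tolerances": {
        "total_cells": {"value": 100},            # absolute difference in cell count
        "cell_type_percentages": {"value": 3.0},  # +/- percentage points per cell type
    },
}
agent_answer = {
    "total_cells": 4980,
    "cell_type_distribution": {"T cell": 44.1, "B cell": 31.2, "NK cell": 24.7},
}

result = DistributionComparisonGrader().evaluate_answer(agent_answer, config)
print(result.passed)     # True: every field is within tolerance
print(result.reasoning)  # per-cell-type PASS/FAIL breakdown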

latch_eval_tools/graders/label_set.py
@@ -0,0 +1,75 @@
+from .base import BinaryGrader, GraderResult
+
+
+class LabelSetJaccardGrader(BinaryGrader):
+    def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+        ground_truth_labels = set(config.get("ground_truth_labels", []))
+        scoring = config.get("scoring", {})
+        pass_threshold = scoring.get("pass_threshold", 0.90)
+        answer_field = config.get("answer_field", "cell_types_predicted")
+
+        if answer_field not in agent_answer:
+            return GraderResult(
+                passed=False,
+                metrics={},
+                reasoning=f"Agent answer missing required field: {answer_field}",
+                agent_answer=agent_answer
+            )
+
+        predicted_labels = set(agent_answer[answer_field])
+
+        intersection = ground_truth_labels & predicted_labels
+        union = ground_truth_labels | predicted_labels
+
+        jaccard_index = len(intersection) / len(union) if len(union) > 0 else 0.0
+        passed = jaccard_index >= pass_threshold
+
+        true_positives = intersection
+        false_positives = predicted_labels - ground_truth_labels
+        false_negatives = ground_truth_labels - predicted_labels
+
+        metrics = {
+            "jaccard_index": jaccard_index,
+            "pass_threshold": pass_threshold,
+            "answer_field": answer_field,
+            "true_positives": sorted(list(true_positives)),
+            "false_positives": sorted(list(false_positives)),
+            "false_negatives": sorted(list(false_negatives)),
+            "predicted_count": len(predicted_labels),
+            "ground_truth_count": len(ground_truth_labels),
+        }
+
+        lines = [
+            f"Label Set Comparison: {'PASS' if passed else 'FAIL'}",
+            "",
+            f"  {'+'if passed else 'x'} Jaccard Index: {jaccard_index:.3f} (threshold: {pass_threshold:.3f})",
+            "",
+            f"Correct Labels ({len(true_positives)}):"
+        ]
+
+        if true_positives:
+            for label in sorted(true_positives):
+                lines.append(f"  + {label}")
+        else:
+            lines.append("  None")
+
+        lines.extend(["", f"Missing Labels ({len(false_negatives)}):"])
+        if false_negatives:
+            for label in sorted(false_negatives):
+                lines.append(f"  - {label}")
+        else:
+            lines.append("  None")
+
+        lines.extend(["", f"Extra Labels ({len(false_positives)}):"])
+        if false_positives:
+            for label in sorted(false_positives):
+                lines.append(f"  ? {label}")
+        else:
+            lines.append("  None")
+
+        return GraderResult(
+            passed=passed,
+            metrics=metrics,
+            reasoning="\n".join(lines),
+            agent_answer=agent_answer
+        )
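
A quick worked example of the Jaccard pass/fail arithmetic (editor's addition, not part of the package; same constructor assumptions as the first sketch).

# Editor's sketch, not package code.
from latch_eval_tools.graders.label_set import LabelSetJaccardGrader

config = {
    "ground_truth_labels": ["T cell", "B cell", "NK cell", "Monocyte"],
    "scoring": {"pass_threshold": 0.75},
    "answer_field": "cell_types_predicted",
}
agent_answer = {"cell_types_predicted": ["T cell", "B cell", "NK cell", "Dendritic cell"]}

result = LabelSetJaccardGrader().evaluate_answer(agent_answer, config)
# intersection = 3, union = 5, Jaccard = 0.600 < 0.75, so result.passed is False;
# "Monocyte" is listed under Missing Labels and "Dendritic cell" under Extra Labels.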

latch_eval_tools/graders/marker_gene.py
@@ -0,0 +1,317 @@
+from .base import BinaryGrader, GraderResult
+
+
+class MarkerGenePrecisionRecallGrader(BinaryGrader):
+    def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+        canonical_markers = config.get("canonical_markers", config.get("ground_truth_labels", []))
+        scoring = config.get("scoring", {})
+        thresholds = scoring.get("pass_thresholds", {})
+        answer_field = config.get("answer_field", "top_marker_genes")
+
+        if answer_field not in agent_answer:
+            for key in agent_answer.keys():
+                if isinstance(agent_answer[key], (list, dict)):
+                    answer_field = key
+                    break
+
+        if answer_field not in agent_answer:
+            return GraderResult(
+                passed=False,
+                metrics={},
+                reasoning=f"Agent answer missing required field. Available keys: {list(agent_answer.keys())}",
+                agent_answer=agent_answer
+            )
+
+        predicted = agent_answer[answer_field]
+
+        if isinstance(canonical_markers, dict) and isinstance(predicted, dict):
+            return self._evaluate_per_celltype(predicted, canonical_markers, thresholds, answer_field, agent_answer)
+
+        if isinstance(canonical_markers, dict) and answer_field in canonical_markers:
+            canonical_markers = canonical_markers[answer_field]
+
+        if not isinstance(predicted, list):
+            return GraderResult(
+                passed=False,
+                metrics={},
+                reasoning=f"{answer_field} must be a list, got {type(predicted).__name__}",
+                agent_answer=agent_answer
+            )
+
+        if not isinstance(canonical_markers, list):
+            return GraderResult(
+                passed=False,
+                metrics={},
+                reasoning=f"canonical_markers must be a list for flat evaluation, got {type(canonical_markers).__name__}",
+                agent_answer=agent_answer
+            )
+
+        return self._evaluate_flat_list(predicted, canonical_markers, thresholds, answer_field, agent_answer)
+
+    def _evaluate_per_celltype(self, predicted: dict, canonical_markers: dict, thresholds: dict, answer_field: str, agent_answer: dict) -> GraderResult:
+        min_recall = thresholds.get("min_recall_per_celltype", thresholds.get("recall_at_k", 0.50))
+        min_celltypes_passing = thresholds.get("min_celltypes_passing", len(canonical_markers))
+
+        celltype_results = {}
+        celltypes_passing = 0
+        total_celltypes = len(canonical_markers)
+
+        for celltype, canonical_genes in canonical_markers.items():
+            predicted_genes = predicted.get(celltype, [])
+            if not isinstance(predicted_genes, list):
+                celltype_results[celltype] = {
+                    "pass": False,
+                    "recall": 0.0,
+                    "error": f"Expected list, got {type(predicted_genes).__name__}"
+                }
+                continue
+
+            predicted_genes = [str(g) for g in predicted_genes]
+            canonical_set = set(str(gene).lower() for gene in canonical_genes)
+            predicted_set = set(gene.lower() for gene in predicted_genes)
+
+            true_positives = canonical_set & predicted_set
+            false_negatives = canonical_set - predicted_set
+
+            recall = len(true_positives) / len(canonical_set) if len(canonical_set) > 0 else 1.0
+            celltype_pass = recall >= min_recall
+
+            if celltype_pass:
+                celltypes_passing += 1
+
+            celltype_results[celltype] = {
+                "pass": celltype_pass,
+                "recall": recall,
+                "num_predicted": len(predicted_genes),
+                "num_canonical": len(canonical_set),
+                "true_positives": sorted(true_positives),
+                "false_negatives": sorted(false_negatives),
+            }
+
+        passed = celltypes_passing >= min_celltypes_passing
+
+        metrics = {
+            "celltypes_passing": celltypes_passing,
+            "total_celltypes": total_celltypes,
+            "min_celltypes_passing": min_celltypes_passing,
+            "min_recall_per_celltype": min_recall,
+            "per_celltype": celltype_results,
+            "answer_field_used": answer_field,
+        }
+
+        lines = [
+            f"Marker Gene Per-Celltype: {'PASS' if passed else 'FAIL'}",
+            f"Celltypes passing: {celltypes_passing}/{total_celltypes} (required: {min_celltypes_passing})",
+            ""
+        ]
+        for celltype, result in celltype_results.items():
+            check = "+" if result["pass"] else "x"
+            lines.append(f"  {check} {celltype}: recall={result['recall']:.2f} (threshold: {min_recall:.2f})")
+
+        return GraderResult(
+            passed=passed,
+            metrics=metrics,
+            reasoning="\n".join(lines),
+            agent_answer=agent_answer
+        )
+
+    def _evaluate_flat_list(self, predicted_genes: list, canonical_markers: list, thresholds: dict, answer_field: str, agent_answer: dict) -> GraderResult:
+        precision_threshold = thresholds.get("precision_at_k", 0.60)
+        recall_threshold = thresholds.get("recall_at_k", 0.50)
+
+        predicted_genes = [str(g) for g in predicted_genes]
+        k = len(predicted_genes)
+
+        canonical_set = set(str(gene).lower() for gene in canonical_markers)
+        predicted_set = set(gene.lower() for gene in predicted_genes)
+
+        true_positives = canonical_set & predicted_set
+        false_positives = predicted_set - canonical_set
+        false_negatives = canonical_set - predicted_set
+
+        precision_at_k = len(true_positives) / k if k > 0 else 0.0
+        recall_at_k = len(true_positives) / len(canonical_set) if len(canonical_set) > 0 else 0.0
+
+        precision_pass = precision_at_k >= precision_threshold
+        recall_pass = recall_at_k >= recall_threshold
+        passed = precision_pass and recall_pass
+
+        original_case_map = {gene.lower(): gene for gene in predicted_genes}
+        canonical_case_map = {str(gene).lower(): str(gene) for gene in canonical_markers}
+
+        true_positive_genes = [original_case_map.get(g, canonical_case_map.get(g, g)) for g in true_positives]
+        false_positive_genes = [original_case_map.get(g, g) for g in false_positives]
+        false_negative_genes = [canonical_case_map.get(g, g) for g in false_negatives]
+
+        metrics = {
+            "k": k,
+            "precision_at_k": precision_at_k,
+            "recall_at_k": recall_at_k,
+            "precision_threshold": precision_threshold,
+            "recall_threshold": recall_threshold,
+            "true_positives": sorted(true_positive_genes),
+            "false_positives": sorted(false_positive_genes),
+            "false_negatives": sorted(false_negative_genes),
+            "num_true_positives": len(true_positives),
+            "num_false_positives": len(false_positives),
+            "num_false_negatives": len(false_negatives),
+            "num_canonical_markers": len(canonical_set),
+            "precision_pass": precision_pass,
+            "recall_pass": recall_pass,
+            "answer_field_used": answer_field,
+        }
+
+        reasoning = self._format_reasoning(
+            k, precision_at_k, recall_at_k, precision_threshold, recall_threshold,
+            true_positive_genes, false_positive_genes, false_negative_genes,
+            precision_pass, recall_pass, passed, answer_field
+        )
+
+        return GraderResult(
+            passed=passed,
+            metrics=metrics,
+            reasoning=reasoning,
+            agent_answer=agent_answer
+        )
+
+    def _format_reasoning(self, k, precision, recall, precision_threshold, recall_threshold,
+                          true_positives, false_positives, false_negatives,
+                          precision_pass, recall_pass, passed, answer_field):
+        lines = [
+            f"Marker Gene Precision/Recall: {'PASS' if passed else 'FAIL'}",
+            f"Answer field: {answer_field}",
+            "",
+            f"  {'+'if precision_pass else 'x'} Precision@{k}: {precision:.3f} (threshold: {precision_threshold:.3f})",
+            f"  {'+'if recall_pass else 'x'} Recall@{k}: {recall:.3f} (threshold: {recall_threshold:.3f})",
+            "",
+            f"True Positives ({len(true_positives)}):"
+        ]
+
+        if true_positives:
+            for gene in sorted(true_positives):
+                lines.append(f"  + {gene}")
+        else:
+            lines.append("  None")
+
+        lines.extend(["", f"False Negatives ({len(false_negatives)}):"])
+        if false_negatives:
+            for gene in sorted(false_negatives):
+                lines.append(f"  - {gene}")
+        else:
+            lines.append("  None")
+
+        if not passed:
+            lines.append("")
+            failures = []
+            if not precision_pass:
+                failures.append(f"Precision {precision:.3f} < {precision_threshold:.3f}")
+            if not recall_pass:
+                failures.append(f"Recall {recall:.3f} < {recall_threshold:.3f}")
+            lines.append(f"Failure: {'; '.join(failures)}")
+
+        return "\n".join(lines)
+
+
+class MarkerGeneSeparationGrader(BinaryGrader):
+    def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+        scoring = config.get("scoring", {})
+        thresholds = scoring.get("pass_thresholds", {})
+        mean_auroc_threshold = thresholds.get("mean_auroc", 0.85)
+        fraction_high_threshold = thresholds.get("fraction_high", 0.70)
+        per_gene_cutoff = thresholds.get("per_gene_cutoff", 0.80)
+
+        if "per_gene_stats" not in agent_answer:
+            return GraderResult(
+                passed=False,
+                metrics={},
+                reasoning="Agent answer missing required field: per_gene_stats",
+                agent_answer=agent_answer
+            )
+
+        if "mean_auroc" not in agent_answer:
+            return GraderResult(
+                passed=False,
+                metrics={},
+                reasoning="Agent answer missing required field: mean_auroc",
+                agent_answer=agent_answer
+            )
+
+        per_gene_stats = agent_answer["per_gene_stats"]
+        agent_mean_auroc = agent_answer["mean_auroc"]
+
+        if not isinstance(per_gene_stats, list):
+            return GraderResult(
+                passed=False,
+                metrics={},
+                reasoning="per_gene_stats must be a list",
+                agent_answer=agent_answer
+            )
+
+        num_genes = len(per_gene_stats)
+        if num_genes == 0:
+            return GraderResult(
+                passed=False,
+                metrics={},
+                reasoning="per_gene_stats is empty",
+                agent_answer=agent_answer
+            )
+
+        gene_aurocs = {}
+        for stat in per_gene_stats:
+            if not isinstance(stat, dict) or "gene" not in stat or "auroc" not in stat:
+                return GraderResult(
+                    passed=False,
+                    metrics={},
+                    reasoning="Each element in per_gene_stats must have 'gene' and 'auroc' fields",
+                    agent_answer=agent_answer
+                )
+            gene_aurocs[stat["gene"]] = stat["auroc"]
+
+        computed_mean_auroc = sum(gene_aurocs.values()) / len(gene_aurocs)
+
+        high_auroc_genes = [gene for gene, auroc in gene_aurocs.items() if auroc >= per_gene_cutoff]
+        low_auroc_genes = [gene for gene, auroc in gene_aurocs.items() if auroc < per_gene_cutoff]
+        fraction_high = len(high_auroc_genes) / num_genes
+
+        mean_auroc_pass = agent_mean_auroc >= mean_auroc_threshold
+        fraction_high_pass = fraction_high >= fraction_high_threshold
+        passed = mean_auroc_pass and fraction_high_pass
+
+        metrics = {
+            "num_genes": num_genes,
+            "mean_auroc_agent": agent_mean_auroc,
+            "mean_auroc_computed": computed_mean_auroc,
+            "mean_auroc_threshold": mean_auroc_threshold,
+            "fraction_high": fraction_high,
+            "fraction_high_threshold": fraction_high_threshold,
+            "per_gene_cutoff": per_gene_cutoff,
+            "num_high_auroc_genes": len(high_auroc_genes),
+            "num_low_auroc_genes": len(low_auroc_genes),
+            "high_auroc_genes": sorted(high_auroc_genes),
+            "low_auroc_genes": sorted(low_auroc_genes),
+            "mean_auroc_pass": mean_auroc_pass,
+            "fraction_high_pass": fraction_high_pass,
+            "per_gene_aurocs": gene_aurocs,
+        }
+
+        lines = [
+            f"Marker Gene Separation: {'PASS' if passed else 'FAIL'}",
+            "",
+            f"  {'+'if mean_auroc_pass else 'x'} Mean AUROC: {agent_mean_auroc:.3f} (threshold: {mean_auroc_threshold:.3f})",
+            f"  {'+'if fraction_high_pass else 'x'} Fraction High (>={per_gene_cutoff:.2f}): {fraction_high:.3f} ({len(high_auroc_genes)}/{num_genes})",
+        ]
+
+        if not passed:
+            failures = []
+            if not mean_auroc_pass:
+                failures.append(f"Mean AUROC {agent_mean_auroc:.3f} < {mean_auroc_threshold:.3f}")
+            if not fraction_high_pass:
+                failures.append(f"Fraction high {fraction_high:.3f} < {fraction_high_threshold:.3f}")
+            lines.append(f"\nFailure: {'; '.join(failures)}")

+        return GraderResult(
+            passed=passed,
+            metrics=metrics,
+            reasoning="\n".join(lines),
+            agent_answer=agent_answer
+        )
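
A sketch of the flat-list precision@k / recall@k path in MarkerGenePrecisionRecallGrader (editor's addition, not part of the package; gene names are compared case-insensitively, and the constructor assumptions from the first sketch apply).

# Editor's sketch, not package code.
from latch_eval_tools.graders.marker_gene import MarkerGenePrecisionRecallGrader

config = {
    "canonical_markers": ["CD3D", "CD3E", "CD8A", "NKG7", "GNLY"],
    "scoring": {"pass_thresholds": {"precision_at_k": 0.60, "recall_at_k": 0.50}},
    "answer_field": "top_marker_genes",
}
agent_answer = {"top_marker_genes": ["cd3d", "CD3E", "CD8A", "MKI67"]}

result = MarkerGenePrecisionRecallGrader().evaluate_answer(agent_answer, config)
# k = 4, true positives = {cd3d, cd3e, cd8a}; precision@4 = 0.75 and recall@4 = 0.60,
# so both thresholds are met and result.passed is True. Passing dicts for both
# canonical_markers and the answer field routes to the per-celltype recall check instead.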

latch_eval_tools/graders/multiple_choice.py
@@ -0,0 +1,38 @@
+from .base import BinaryGrader, GraderResult
+
+
+class MultipleChoiceGrader(BinaryGrader):
+    def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+        if "correct_answers" in config:
+            correct_answers = [a.strip().upper() for a in config["correct_answers"]]
+        else:
+            correct_answers = [config.get("correct_answer", "").strip().upper()]
+
+        if "answer" not in agent_answer:
+            return GraderResult(
+                passed=False,
+                metrics={},
+                reasoning="Agent answer missing required field: answer",
+                agent_answer=agent_answer
+            )
+
+        agent_choice = str(agent_answer["answer"]).strip().upper()
+        passed = agent_choice in correct_answers
+
+        display_correct = correct_answers[0] if len(correct_answers) == 1 else correct_answers
+        metrics = {
+            "correct_answers": correct_answers,
+            "agent_answer": agent_choice,
+        }
+
+        if passed:
+            reasoning = f"Multiple Choice: PASS\n\n + Agent answered: {agent_choice} (correct)"
+        else:
+            reasoning = f"Multiple Choice: FAIL\n\n x Agent answered: {agent_choice}\n Correct answer(s): {display_correct}"
+
+        return GraderResult(
+            passed=passed,
+            metrics=metrics,
+            reasoning=reasoning,
+            agent_answer=agent_answer
+        )
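
A short illustration of the normalization above (editor's addition, not part of the package): answers are passed through str(...).strip().upper() before comparison, so case and surrounding whitespace do not matter.

# Editor's sketch, not package code; same constructor assumptions as the first sketch.
from latch_eval_tools.graders.multiple_choice import MultipleChoiceGrader

result = MultipleChoiceGrader().evaluate_answer({"answer": " b "}, {"correct_answer": "B"})
# result.passed is True; with {"correct_answers": ["B", "C"]} any listed choice is accepted.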

latch_eval_tools/graders/numeric.py
@@ -0,0 +1,137 @@
+from .base import BinaryGrader, GraderResult, get_nested_value
+
+
+class NumericToleranceGrader(BinaryGrader):
+    def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+        ground_truth = config.get("ground_truth", {})
+        tolerances = config.get("tolerances", config.get("tolerance", {}))
+
+        metrics = {}
+        all_pass = True
+        failures = []
+
+        for field, expected_value in ground_truth.items():
+            actual_value, found = get_nested_value(agent_answer, field)
+            if not found:
+                all_pass = False
+                failures.append(f"Missing field: {field}")
+                continue
+
+            if isinstance(actual_value, str):
+                try:
+                    actual_value = float(actual_value)
+                except ValueError:
+                    all_pass = False
+                    failures.append(f"{field}: cannot parse '{actual_value}' as number")
+                    continue
+
+            if isinstance(actual_value, bool):
+                actual_value = int(actual_value)
+
+            if actual_value is None:
+                all_pass = False
+                failures.append(f"{field}: got null/None value")
+                metrics[f"{field}_actual"] = None
+                metrics[f"{field}_expected"] = expected_value
+                metrics[f"{field}_error"] = float('inf')
+                metrics[f"{field}_pass"] = False
+                continue
+
+            tolerance_config = tolerances.get(field, {"type": "absolute", "value": 0})
+            if isinstance(tolerances, dict) and "type" in tolerances and "value" not in tolerances.get(field, {}):
+                tolerance_config = tolerances
+            tolerance_type = tolerance_config.get("type", "absolute")
+            has_asymmetric = "lower" in tolerance_config and "upper" in tolerance_config
+            tolerance_value = tolerance_config.get("value", 0)
+            tolerance_lower = tolerance_config.get("lower", tolerance_value)
+            tolerance_upper = tolerance_config.get("upper", tolerance_value)
+
+            try:
+                if tolerance_type == "absolute":
+                    if has_asymmetric:
+                        within_tolerance = (expected_value - tolerance_lower) <= actual_value <= (expected_value + tolerance_upper)
+                        error = actual_value - expected_value
+                    else:
+                        within_tolerance = abs(actual_value - expected_value) <= tolerance_value
+                        error = abs(actual_value - expected_value)
+                elif tolerance_type == "relative":
+                    relative_error = abs(actual_value - expected_value) / abs(expected_value) if expected_value != 0 else float('inf')
+                    within_tolerance = relative_error <= tolerance_value
+                    error = relative_error
+                elif tolerance_type == "min":
+                    threshold = tolerance_value
+                    within_tolerance = actual_value >= threshold
+                    error = threshold - actual_value if actual_value < threshold else 0
+                elif tolerance_type == "max":
+                    threshold = tolerance_value
+                    within_tolerance = actual_value <= threshold
+                    error = actual_value - threshold if actual_value > threshold else 0
+                else:
+                    within_tolerance = False
+                    error = float('inf')
+            except TypeError:
+                all_pass = False
+                failures.append(f"{field}: invalid type {type(actual_value).__name__}, expected numeric")
+                metrics[f"{field}_actual"] = actual_value
+                metrics[f"{field}_expected"] = expected_value
+                metrics[f"{field}_error"] = float('inf')
+                metrics[f"{field}_pass"] = False
+                continue
+
+            metrics[f"{field}_actual"] = actual_value
+            metrics[f"{field}_expected"] = expected_value
+            metrics[f"{field}_error"] = error
+            metrics[f"{field}_pass"] = within_tolerance
+
+            if not within_tolerance:
+                all_pass = False
+                if tolerance_type == "min":
+                    failures.append(f"{field}: {actual_value} (minimum required: {tolerance_value})")
+                elif tolerance_type == "max":
+                    failures.append(f"{field}: {actual_value} (maximum allowed: {tolerance_value})")
+                elif has_asymmetric:
+                    failures.append(f"{field}: {actual_value} vs {expected_value} (allowed: -{tolerance_lower}/+{tolerance_upper})")
+                else:
+                    failures.append(f"{field}: {actual_value} vs {expected_value} (error: {error:.2f}, tolerance: {tolerance_value})")
+
+        reasoning = self._format_reasoning(ground_truth, tolerances, metrics, failures, all_pass)
+
+        return GraderResult(
+            passed=all_pass,
+            metrics=metrics,
+            reasoning=reasoning,
+            agent_answer=agent_answer
+        )
+
+    def _format_reasoning(self, ground_truth, tolerances, metrics, failures, passed):
+        lines = [f"Numeric Tolerance Check: {'PASS' if passed else 'FAIL'}", ""]
+
+        for field in ground_truth.keys():
+            if f"{field}_actual" in metrics:
+                actual = metrics[f"{field}_actual"]
+                expected = metrics[f"{field}_expected"]
+                error = metrics[f"{field}_error"]
+                field_pass = metrics[f"{field}_pass"]
+                check = "+" if field_pass else "x"
+                tolerance_config = tolerances.get(field, {}) if isinstance(tolerances, dict) else {}
+                tolerance_type = tolerance_config.get("type", "absolute")
+                has_asymmetric = "lower" in tolerance_config and "upper" in tolerance_config
+                if tolerance_type == "min":
+                    tol_val = tolerance_config.get("value", expected)
+                    lines.append(f"  {check} {field}: {actual} (minimum: {tol_val})")
+                elif tolerance_type == "max":
+                    tol_val = tolerance_config.get("value", expected)
+                    lines.append(f"  {check} {field}: {actual} (maximum: {tol_val})")
+                elif has_asymmetric:
+                    lower = tolerance_config["lower"]
+                    upper = tolerance_config["upper"]
+                    lines.append(f"  {check} {field}: {actual} vs {expected} (allowed: -{lower}/+{upper})")
+                else:
+                    lines.append(f"  {check} {field}: {actual} vs {expected} (error: {error:.4f})")
+
+        if not passed and failures:
+            lines.extend(["", "Failures:"])
+            for failure in failures:
+                lines.append(f"  - {failure}")
+
+        return "\n".join(lines)
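
A sketch of the per-field tolerance types handled above (editor's addition, not part of the package). It uses only top-level keys, so it assumes get_nested_value (imported from base.py, not shown in this diff) returns the usual (value, found) pair for flat field names; constructor assumptions match the first sketch.

# Editor's sketch, not package code.
from latch_eval_tools.graders.numeric import NumericToleranceGrader

config = {
    "ground_truth": {"n_cells": 5000, "median_genes": 2400, "doublet_rate": 0.08},
    "tolerances": {
        "n_cells": {"type": "absolute", "value": 250},
        "median_genes": {"type": "relative", "value": 0.10},
        "doublet_rate": {"type": "max", "value": 0.10},
    },
}
agent_answer = {"n_cells": 5120, "median_genes": 2310, "doublet_rate": 0.07}

result = NumericToleranceGrader().evaluate_answer(agent_answer, config)
# n_cells: |5120 - 5000| = 120 <= 250; median_genes: relative error 90/2400 = 0.0375 <= 0.10;
# doublet_rate: 0.07 <= 0.10 -> result.passed is True.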