latch-eval-tools 0.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. latch_eval_tools/__init__.py +64 -0
  2. latch_eval_tools/answer_extraction.py +35 -0
  3. latch_eval_tools/cli/__init__.py +0 -0
  4. latch_eval_tools/cli/eval_lint.py +185 -0
  5. latch_eval_tools/eval_server.py +570 -0
  6. latch_eval_tools/faas_utils.py +13 -0
  7. latch_eval_tools/graders/__init__.py +40 -0
  8. latch_eval_tools/graders/base.py +29 -0
  9. latch_eval_tools/graders/distribution.py +102 -0
  10. latch_eval_tools/graders/label_set.py +75 -0
  11. latch_eval_tools/graders/marker_gene.py +317 -0
  12. latch_eval_tools/graders/multiple_choice.py +38 -0
  13. latch_eval_tools/graders/numeric.py +137 -0
  14. latch_eval_tools/graders/spatial.py +93 -0
  15. latch_eval_tools/harness/__init__.py +27 -0
  16. latch_eval_tools/harness/claudecode.py +212 -0
  17. latch_eval_tools/harness/minisweagent.py +265 -0
  18. latch_eval_tools/harness/plotsagent.py +156 -0
  19. latch_eval_tools/harness/runner.py +191 -0
  20. latch_eval_tools/harness/utils.py +191 -0
  21. latch_eval_tools/headless_eval_server.py +727 -0
  22. latch_eval_tools/linter/__init__.py +25 -0
  23. latch_eval_tools/linter/explanations.py +331 -0
  24. latch_eval_tools/linter/runner.py +146 -0
  25. latch_eval_tools/linter/schema.py +126 -0
  26. latch_eval_tools/linter/validators.py +595 -0
  27. latch_eval_tools/types.py +30 -0
  28. latch_eval_tools/wrapper_entrypoint.py +316 -0
  29. latch_eval_tools-0.1.0.dist-info/METADATA +118 -0
  30. latch_eval_tools-0.1.0.dist-info/RECORD +33 -0
  31. latch_eval_tools-0.1.0.dist-info/WHEEL +4 -0
  32. latch_eval_tools-0.1.0.dist-info/entry_points.txt +2 -0
  33. latch_eval_tools-0.1.0.dist-info/licenses/LICENSE +1 -0
latch_eval_tools/graders/distribution.py
@@ -0,0 +1,102 @@
+ from .base import BinaryGrader, GraderResult
+
+
+ class DistributionComparisonGrader(BinaryGrader):
+     def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+         ground_truth = config.get("ground_truth", {})
+         tolerances = config.get("tolerances", {})
+
+         gt_total_cells = ground_truth.get("total_cells")
+         gt_distribution = ground_truth.get("cell_type_distribution", {})
+
+         total_cells_tolerance = tolerances.get("total_cells", {})
+         pct_tolerance_config = tolerances.get("cell_type_percentages", {})
+         pct_tolerance = pct_tolerance_config.get("value", 3.0)
+
+         if "cell_type_distribution" not in agent_answer:
+             return GraderResult(
+                 passed=False,
+                 metrics={},
+                 reasoning="Agent answer missing required field: cell_type_distribution",
+                 agent_answer=agent_answer
+             )
+
+         agent_total_cells = agent_answer.get("total_cells")
+         agent_distribution = agent_answer["cell_type_distribution"]
+
+         metrics = {}
+         all_pass = True
+         failures = []
+
+         if gt_total_cells is not None and agent_total_cells is not None:
+             total_cells_tol_value = total_cells_tolerance.get("value", 0)
+             total_cells_diff = abs(agent_total_cells - gt_total_cells)
+             total_cells_pass = total_cells_diff <= total_cells_tol_value
+
+             metrics["total_cells_actual"] = agent_total_cells
+             metrics["total_cells_expected"] = gt_total_cells
+             metrics["total_cells_diff"] = total_cells_diff
+             metrics["total_cells_pass"] = total_cells_pass
+
+             if not total_cells_pass:
+                 all_pass = False
+                 failures.append(f"total_cells: {agent_total_cells} vs {gt_total_cells} (diff: {total_cells_diff})")
+
+         distribution_failures = []
+         for cell_type, expected_pct in gt_distribution.items():
+             if cell_type not in agent_distribution:
+                 all_pass = False
+                 failures.append(f"Missing cell type: {cell_type}")
+                 distribution_failures.append(cell_type)
+                 metrics[f"{cell_type}_actual"] = None
+                 metrics[f"{cell_type}_expected"] = expected_pct
+                 metrics[f"{cell_type}_diff"] = None
+                 metrics[f"{cell_type}_pass"] = False
+                 continue
+
+             actual_pct = agent_distribution[cell_type]
+             diff = abs(actual_pct - expected_pct)
+             within_tolerance = diff <= pct_tolerance
+
+             metrics[f"{cell_type}_actual"] = actual_pct
+             metrics[f"{cell_type}_expected"] = expected_pct
+             metrics[f"{cell_type}_diff"] = diff
+             metrics[f"{cell_type}_pass"] = within_tolerance
+
+             if not within_tolerance:
+                 all_pass = False
+                 failures.append(f"{cell_type}: {actual_pct:.2f}% vs {expected_pct:.2f}% (diff: {diff:.2f}%)")
+                 distribution_failures.append(cell_type)
+
+         extra_types = set(agent_distribution.keys()) - set(gt_distribution.keys())
+         if extra_types:
+             metrics["extra_cell_types"] = sorted(list(extra_types))
+
+         lines = [
+             f"Distribution Comparison: {'PASS' if all_pass else 'FAIL'}",
+             "",
+             f"Cell type percentages (tolerance: +/-{pct_tolerance}%):"
+         ]
+
+         for cell_type in sorted(gt_distribution.keys()):
+             expected = gt_distribution[cell_type]
+             if cell_type not in agent_distribution:
+                 lines.append(f" x {cell_type}: MISSING vs {expected:.2f}%")
+             else:
+                 actual = agent_distribution[cell_type]
+                 diff = abs(actual - expected)
+                 within_tol = diff <= pct_tolerance
+                 check = "+" if within_tol else "x"
+                 lines.append(f" {check} {cell_type}: {actual:.2f}% vs {expected:.2f}% (diff: {diff:.2f}%)")
+
+         if failures:
+             lines.extend(["", "Failures:"])
+             for failure in failures:
+                 lines.append(f" - {failure}")
+
+         return GraderResult(
+             passed=all_pass,
+             metrics=metrics,
+             reasoning="\n".join(lines),
+             agent_answer=agent_answer
+         )
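Below is a minimal, hypothetical usage sketch for DistributionComparisonGrader as reconstructed above. The import path, the no-argument constructor, attribute access on GraderResult, and all config values are assumptions made for illustration, not verified behavior of the released wheel.

# Hypothetical usage sketch; import path and no-arg constructor are assumed.
from latch_eval_tools.graders.distribution import DistributionComparisonGrader

config = {
    "ground_truth": {
        "total_cells": 5000,
        "cell_type_distribution": {"T cell": 40.0, "B cell": 25.0, "Monocyte": 35.0},
    },
    "tolerances": {
        "total_cells": {"value": 250},            # absolute slack on the cell count
        "cell_type_percentages": {"value": 3.0},  # +/- percentage points per cell type
    },
}
agent_answer = {
    "total_cells": 5120,
    "cell_type_distribution": {"T cell": 41.2, "B cell": 23.9, "Monocyte": 34.9},
}

result = DistributionComparisonGrader().evaluate_answer(agent_answer, config)
print(result.passed)     # True: every percentage is within 3.0 points and the count is within 250
print(result.reasoning)  # multi-line PASS/FAIL report built by the grader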
latch_eval_tools/graders/label_set.py
@@ -0,0 +1,75 @@
+ from .base import BinaryGrader, GraderResult
+
+
+ class LabelSetJaccardGrader(BinaryGrader):
+     def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+         ground_truth_labels = set(config.get("ground_truth_labels", []))
+         scoring = config.get("scoring", {})
+         pass_threshold = scoring.get("pass_threshold", 0.90)
+         answer_field = config.get("answer_field", "cell_types_predicted")
+
+         if answer_field not in agent_answer:
+             return GraderResult(
+                 passed=False,
+                 metrics={},
+                 reasoning=f"Agent answer missing required field: {answer_field}",
+                 agent_answer=agent_answer
+             )
+
+         predicted_labels = set(agent_answer[answer_field])
+
+         intersection = ground_truth_labels & predicted_labels
+         union = ground_truth_labels | predicted_labels
+
+         jaccard_index = len(intersection) / len(union) if len(union) > 0 else 0.0
+         passed = jaccard_index >= pass_threshold
+
+         true_positives = intersection
+         false_positives = predicted_labels - ground_truth_labels
+         false_negatives = ground_truth_labels - predicted_labels
+
+         metrics = {
+             "jaccard_index": jaccard_index,
+             "pass_threshold": pass_threshold,
+             "answer_field": answer_field,
+             "true_positives": sorted(list(true_positives)),
+             "false_positives": sorted(list(false_positives)),
+             "false_negatives": sorted(list(false_negatives)),
+             "predicted_count": len(predicted_labels),
+             "ground_truth_count": len(ground_truth_labels),
+         }
+
+         lines = [
+             f"Label Set Comparison: {'PASS' if passed else 'FAIL'}",
+             "",
+             f" {'+' if passed else 'x'} Jaccard Index: {jaccard_index:.3f} (threshold: {pass_threshold:.3f})",
+             "",
+             f"Correct Labels ({len(true_positives)}):"
+         ]
+
+         if true_positives:
+             for label in sorted(true_positives):
+                 lines.append(f" + {label}")
+         else:
+             lines.append(" None")
+
+         lines.extend(["", f"Missing Labels ({len(false_negatives)}):"])
+         if false_negatives:
+             for label in sorted(false_negatives):
+                 lines.append(f" - {label}")
+         else:
+             lines.append(" None")
+
+         lines.extend(["", f"Extra Labels ({len(false_positives)}):"])
+         if false_positives:
+             for label in sorted(false_positives):
+                 lines.append(f" ? {label}")
+         else:
+             lines.append(" None")
+
+         return GraderResult(
+             passed=passed,
+             metrics=metrics,
+             reasoning="\n".join(lines),
+             agent_answer=agent_answer
+         )
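A minimal, hypothetical sketch of how LabelSetJaccardGrader might be driven. The import path, no-argument constructor, and example labels are assumptions for illustration only; the Jaccard arithmetic in the comment follows directly from the code above.

# Hypothetical usage sketch; import path and no-arg constructor are assumed.
from latch_eval_tools.graders.label_set import LabelSetJaccardGrader

config = {
    "ground_truth_labels": ["T cell", "B cell", "NK cell", "Monocyte"],
    "scoring": {"pass_threshold": 0.75},  # default in the grader is 0.90
}
agent_answer = {"cell_types_predicted": ["T cell", "B cell", "Monocyte", "Dendritic cell"]}

result = LabelSetJaccardGrader().evaluate_answer(agent_answer, config)
# intersection = 3 labels, union = 5 labels, so Jaccard = 0.6 < 0.75 and the check fails
print(result.passed, result.metrics["jaccard_index"])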
latch_eval_tools/graders/marker_gene.py
@@ -0,0 +1,317 @@
+ from .base import BinaryGrader, GraderResult
+
+
+ class MarkerGenePrecisionRecallGrader(BinaryGrader):
+     def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+         canonical_markers = config.get("canonical_markers", config.get("ground_truth_labels", []))
+         scoring = config.get("scoring", {})
+         thresholds = scoring.get("pass_thresholds", {})
+         answer_field = config.get("answer_field", "top_marker_genes")
+
+         if answer_field not in agent_answer:
+             for key in agent_answer.keys():
+                 if isinstance(agent_answer[key], (list, dict)):
+                     answer_field = key
+                     break
+
+         if answer_field not in agent_answer:
+             return GraderResult(
+                 passed=False,
+                 metrics={},
+                 reasoning=f"Agent answer missing required field. Available keys: {list(agent_answer.keys())}",
+                 agent_answer=agent_answer
+             )
+
+         predicted = agent_answer[answer_field]
+
+         if isinstance(canonical_markers, dict) and isinstance(predicted, dict):
+             return self._evaluate_per_celltype(predicted, canonical_markers, thresholds, answer_field, agent_answer)
+
+         if isinstance(canonical_markers, dict) and answer_field in canonical_markers:
+             canonical_markers = canonical_markers[answer_field]
+
+         if not isinstance(predicted, list):
+             return GraderResult(
+                 passed=False,
+                 metrics={},
+                 reasoning=f"{answer_field} must be a list, got {type(predicted).__name__}",
+                 agent_answer=agent_answer
+             )
+
+         if not isinstance(canonical_markers, list):
+             return GraderResult(
+                 passed=False,
+                 metrics={},
+                 reasoning=f"canonical_markers must be a list for flat evaluation, got {type(canonical_markers).__name__}",
+                 agent_answer=agent_answer
+             )
+
+         return self._evaluate_flat_list(predicted, canonical_markers, thresholds, answer_field, agent_answer)
+
+     def _evaluate_per_celltype(self, predicted: dict, canonical_markers: dict, thresholds: dict, answer_field: str, agent_answer: dict) -> GraderResult:
+         min_recall = thresholds.get("min_recall_per_celltype", thresholds.get("recall_at_k", 0.50))
+         min_celltypes_passing = thresholds.get("min_celltypes_passing", len(canonical_markers))
+
+         celltype_results = {}
+         celltypes_passing = 0
+         total_celltypes = len(canonical_markers)
+
+         for celltype, canonical_genes in canonical_markers.items():
+             predicted_genes = predicted.get(celltype, [])
+             if not isinstance(predicted_genes, list):
+                 celltype_results[celltype] = {
+                     "pass": False,
+                     "recall": 0.0,
+                     "error": f"Expected list, got {type(predicted_genes).__name__}"
+                 }
+                 continue
+
+             predicted_genes = [str(g) for g in predicted_genes]
+             canonical_set = set(str(gene).lower() for gene in canonical_genes)
+             predicted_set = set(gene.lower() for gene in predicted_genes)
+
+             true_positives = canonical_set & predicted_set
+             false_negatives = canonical_set - predicted_set
+
+             recall = len(true_positives) / len(canonical_set) if len(canonical_set) > 0 else 1.0
+             celltype_pass = recall >= min_recall
+
+             if celltype_pass:
+                 celltypes_passing += 1
+
+             celltype_results[celltype] = {
+                 "pass": celltype_pass,
+                 "recall": recall,
+                 "num_predicted": len(predicted_genes),
+                 "num_canonical": len(canonical_set),
+                 "true_positives": sorted(true_positives),
+                 "false_negatives": sorted(false_negatives),
+             }
+
+         passed = celltypes_passing >= min_celltypes_passing
+
+         metrics = {
+             "celltypes_passing": celltypes_passing,
+             "total_celltypes": total_celltypes,
+             "min_celltypes_passing": min_celltypes_passing,
+             "min_recall_per_celltype": min_recall,
+             "per_celltype": celltype_results,
+             "answer_field_used": answer_field,
+         }
+
+         lines = [
+             f"Marker Gene Per-Celltype: {'PASS' if passed else 'FAIL'}",
+             f"Celltypes passing: {celltypes_passing}/{total_celltypes} (required: {min_celltypes_passing})",
+             ""
+         ]
+         for celltype, result in celltype_results.items():
+             check = "+" if result["pass"] else "x"
+             lines.append(f" {check} {celltype}: recall={result['recall']:.2f} (threshold: {min_recall:.2f})")
+
+         return GraderResult(
+             passed=passed,
+             metrics=metrics,
+             reasoning="\n".join(lines),
+             agent_answer=agent_answer
+         )
+
+     def _evaluate_flat_list(self, predicted_genes: list, canonical_markers: list, thresholds: dict, answer_field: str, agent_answer: dict) -> GraderResult:
+         precision_threshold = thresholds.get("precision_at_k", 0.60)
+         recall_threshold = thresholds.get("recall_at_k", 0.50)
+
+         predicted_genes = [str(g) for g in predicted_genes]
+         k = len(predicted_genes)
+
+         canonical_set = set(str(gene).lower() for gene in canonical_markers)
+         predicted_set = set(gene.lower() for gene in predicted_genes)
+
+         true_positives = canonical_set & predicted_set
+         false_positives = predicted_set - canonical_set
+         false_negatives = canonical_set - predicted_set
+
+         precision_at_k = len(true_positives) / k if k > 0 else 0.0
+         recall_at_k = len(true_positives) / len(canonical_set) if len(canonical_set) > 0 else 0.0
+
+         precision_pass = precision_at_k >= precision_threshold
+         recall_pass = recall_at_k >= recall_threshold
+         passed = precision_pass and recall_pass
+
+         original_case_map = {gene.lower(): gene for gene in predicted_genes}
+         canonical_case_map = {str(gene).lower(): str(gene) for gene in canonical_markers}
+
+         true_positive_genes = [original_case_map.get(g, canonical_case_map.get(g, g)) for g in true_positives]
+         false_positive_genes = [original_case_map.get(g, g) for g in false_positives]
+         false_negative_genes = [canonical_case_map.get(g, g) for g in false_negatives]
+
+         metrics = {
+             "k": k,
+             "precision_at_k": precision_at_k,
+             "recall_at_k": recall_at_k,
+             "precision_threshold": precision_threshold,
+             "recall_threshold": recall_threshold,
+             "true_positives": sorted(true_positive_genes),
+             "false_positives": sorted(false_positive_genes),
+             "false_negatives": sorted(false_negative_genes),
+             "num_true_positives": len(true_positives),
+             "num_false_positives": len(false_positives),
+             "num_false_negatives": len(false_negatives),
+             "num_canonical_markers": len(canonical_set),
+             "precision_pass": precision_pass,
+             "recall_pass": recall_pass,
+             "answer_field_used": answer_field,
+         }
+
+         reasoning = self._format_reasoning(
+             k, precision_at_k, recall_at_k, precision_threshold, recall_threshold,
+             true_positive_genes, false_positive_genes, false_negative_genes,
+             precision_pass, recall_pass, passed, answer_field
+         )
+
+         return GraderResult(
+             passed=passed,
+             metrics=metrics,
+             reasoning=reasoning,
+             agent_answer=agent_answer
+         )
+
+     def _format_reasoning(self, k, precision, recall, precision_threshold, recall_threshold,
+                           true_positives, false_positives, false_negatives,
+                           precision_pass, recall_pass, passed, answer_field):
+         lines = [
+             f"Marker Gene Precision/Recall: {'PASS' if passed else 'FAIL'}",
+             f"Answer field: {answer_field}",
+             "",
+             f" {'+' if precision_pass else 'x'} Precision@{k}: {precision:.3f} (threshold: {precision_threshold:.3f})",
+             f" {'+' if recall_pass else 'x'} Recall@{k}: {recall:.3f} (threshold: {recall_threshold:.3f})",
+             "",
+             f"True Positives ({len(true_positives)}):"
+         ]
+
+         if true_positives:
+             for gene in sorted(true_positives):
+                 lines.append(f" + {gene}")
+         else:
+             lines.append(" None")
+
+         lines.extend(["", f"False Negatives ({len(false_negatives)}):"])
+         if false_negatives:
+             for gene in sorted(false_negatives):
+                 lines.append(f" - {gene}")
+         else:
+             lines.append(" None")
+
+         if not passed:
+             lines.append("")
+             failures = []
+             if not precision_pass:
+                 failures.append(f"Precision {precision:.3f} < {precision_threshold:.3f}")
+             if not recall_pass:
+                 failures.append(f"Recall {recall:.3f} < {recall_threshold:.3f}")
+             lines.append(f"Failure: {'; '.join(failures)}")
+
+         return "\n".join(lines)
+
+
+ class MarkerGeneSeparationGrader(BinaryGrader):
+     def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+         scoring = config.get("scoring", {})
+         thresholds = scoring.get("pass_thresholds", {})
+         mean_auroc_threshold = thresholds.get("mean_auroc", 0.85)
+         fraction_high_threshold = thresholds.get("fraction_high", 0.70)
+         per_gene_cutoff = thresholds.get("per_gene_cutoff", 0.80)
+
+         if "per_gene_stats" not in agent_answer:
+             return GraderResult(
+                 passed=False,
+                 metrics={},
+                 reasoning="Agent answer missing required field: per_gene_stats",
+                 agent_answer=agent_answer
+             )
+
+         if "mean_auroc" not in agent_answer:
+             return GraderResult(
+                 passed=False,
+                 metrics={},
+                 reasoning="Agent answer missing required field: mean_auroc",
+                 agent_answer=agent_answer
+             )
+
+         per_gene_stats = agent_answer["per_gene_stats"]
+         agent_mean_auroc = agent_answer["mean_auroc"]
+
+         if not isinstance(per_gene_stats, list):
+             return GraderResult(
+                 passed=False,
+                 metrics={},
+                 reasoning="per_gene_stats must be a list",
+                 agent_answer=agent_answer
+             )
+
+         num_genes = len(per_gene_stats)
+         if num_genes == 0:
+             return GraderResult(
+                 passed=False,
+                 metrics={},
+                 reasoning="per_gene_stats is empty",
+                 agent_answer=agent_answer
+             )
+
+         gene_aurocs = {}
+         for stat in per_gene_stats:
+             if not isinstance(stat, dict) or "gene" not in stat or "auroc" not in stat:
+                 return GraderResult(
+                     passed=False,
+                     metrics={},
+                     reasoning="Each element in per_gene_stats must have 'gene' and 'auroc' fields",
+                     agent_answer=agent_answer
+                 )
+             gene_aurocs[stat["gene"]] = stat["auroc"]
+
+         computed_mean_auroc = sum(gene_aurocs.values()) / len(gene_aurocs)
+
+         high_auroc_genes = [gene for gene, auroc in gene_aurocs.items() if auroc >= per_gene_cutoff]
+         low_auroc_genes = [gene for gene, auroc in gene_aurocs.items() if auroc < per_gene_cutoff]
+         fraction_high = len(high_auroc_genes) / num_genes
+
+         mean_auroc_pass = agent_mean_auroc >= mean_auroc_threshold
+         fraction_high_pass = fraction_high >= fraction_high_threshold
+         passed = mean_auroc_pass and fraction_high_pass
+
+         metrics = {
+             "num_genes": num_genes,
+             "mean_auroc_agent": agent_mean_auroc,
+             "mean_auroc_computed": computed_mean_auroc,
+             "mean_auroc_threshold": mean_auroc_threshold,
+             "fraction_high": fraction_high,
+             "fraction_high_threshold": fraction_high_threshold,
+             "per_gene_cutoff": per_gene_cutoff,
+             "num_high_auroc_genes": len(high_auroc_genes),
+             "num_low_auroc_genes": len(low_auroc_genes),
+             "high_auroc_genes": sorted(high_auroc_genes),
+             "low_auroc_genes": sorted(low_auroc_genes),
+             "mean_auroc_pass": mean_auroc_pass,
+             "fraction_high_pass": fraction_high_pass,
+             "per_gene_aurocs": gene_aurocs,
+         }
+
+         lines = [
+             f"Marker Gene Separation: {'PASS' if passed else 'FAIL'}",
+             "",
+             f" {'+' if mean_auroc_pass else 'x'} Mean AUROC: {agent_mean_auroc:.3f} (threshold: {mean_auroc_threshold:.3f})",
+             f" {'+' if fraction_high_pass else 'x'} Fraction High (>={per_gene_cutoff:.2f}): {fraction_high:.3f} ({len(high_auroc_genes)}/{num_genes})",
+         ]
+
+         if not passed:
+             failures = []
+             if not mean_auroc_pass:
+                 failures.append(f"Mean AUROC {agent_mean_auroc:.3f} < {mean_auroc_threshold:.3f}")
+             if not fraction_high_pass:
+                 failures.append(f"Fraction high {fraction_high:.3f} < {fraction_high_threshold:.3f}")
+             lines.append(f"\nFailure: {'; '.join(failures)}")
+
+         return GraderResult(
+             passed=passed,
+             metrics=metrics,
+             reasoning="\n".join(lines),
+             agent_answer=agent_answer
+         )
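A minimal, hypothetical sketch of the flat-list path of MarkerGenePrecisionRecallGrader above. The import path, no-argument constructor, and gene lists are assumptions for illustration; the precision/recall arithmetic in the comment follows from the code.

# Hypothetical usage sketch; import path and no-arg constructor are assumed.
from latch_eval_tools.graders.marker_gene import MarkerGenePrecisionRecallGrader

config = {
    "canonical_markers": ["CD3D", "CD3E", "CD8A", "NKG7"],
    "scoring": {"pass_thresholds": {"precision_at_k": 0.60, "recall_at_k": 0.50}},
    "answer_field": "top_marker_genes",
}
agent_answer = {"top_marker_genes": ["CD3D", "CD3E", "GNLY", "CD8A", "IL7R"]}

result = MarkerGenePrecisionRecallGrader().evaluate_answer(agent_answer, config)
# Matching is case-insensitive: 3 of 5 predictions hit (precision@5 = 0.60) and
# 3 of 4 canonical markers are recovered (recall@5 = 0.75), so both thresholds pass.
print(result.passed)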
latch_eval_tools/graders/multiple_choice.py
@@ -0,0 +1,38 @@
+ from .base import BinaryGrader, GraderResult
+
+
+ class MultipleChoiceGrader(BinaryGrader):
+     def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+         if "correct_answers" in config:
+             correct_answers = [a.strip().upper() for a in config["correct_answers"]]
+         else:
+             correct_answers = [config.get("correct_answer", "").strip().upper()]
+
+         if "answer" not in agent_answer:
+             return GraderResult(
+                 passed=False,
+                 metrics={},
+                 reasoning="Agent answer missing required field: answer",
+                 agent_answer=agent_answer
+             )
+
+         agent_choice = str(agent_answer["answer"]).strip().upper()
+         passed = agent_choice in correct_answers
+
+         display_correct = correct_answers[0] if len(correct_answers) == 1 else correct_answers
+         metrics = {
+             "correct_answers": correct_answers,
+             "agent_answer": agent_choice,
+         }
+
+         if passed:
+             reasoning = f"Multiple Choice: PASS\n\n + Agent answered: {agent_choice} (correct)"
+         else:
+             reasoning = f"Multiple Choice: FAIL\n\n x Agent answered: {agent_choice}\n Correct answer(s): {display_correct}"
+
+         return GraderResult(
+             passed=passed,
+             metrics=metrics,
+             reasoning=reasoning,
+             agent_answer=agent_answer
+         )
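A minimal, hypothetical sketch for MultipleChoiceGrader above; the import path, no-argument constructor, and config values are assumed for illustration.

# Hypothetical usage sketch; import path and no-arg constructor are assumed.
from latch_eval_tools.graders.multiple_choice import MultipleChoiceGrader

config = {"correct_answer": "b"}   # or "correct_answers": ["B", "C"] for multiple accepted choices
agent_answer = {"answer": " b "}   # whitespace and case are normalized before comparison

result = MultipleChoiceGrader().evaluate_answer(agent_answer, config)
print(result.passed)  # True: both sides are stripped and upper-cased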
latch_eval_tools/graders/numeric.py
@@ -0,0 +1,137 @@
+ from .base import BinaryGrader, GraderResult, get_nested_value
+
+
+ class NumericToleranceGrader(BinaryGrader):
+     def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+         ground_truth = config.get("ground_truth", {})
+         tolerances = config.get("tolerances", config.get("tolerance", {}))
+
+         metrics = {}
+         all_pass = True
+         failures = []
+
+         for field, expected_value in ground_truth.items():
+             actual_value, found = get_nested_value(agent_answer, field)
+             if not found:
+                 all_pass = False
+                 failures.append(f"Missing field: {field}")
+                 continue
+
+             if isinstance(actual_value, str):
+                 try:
+                     actual_value = float(actual_value)
+                 except ValueError:
+                     all_pass = False
+                     failures.append(f"{field}: cannot parse '{actual_value}' as number")
+                     continue
+
+             if isinstance(actual_value, bool):
+                 actual_value = int(actual_value)
+
+             if actual_value is None:
+                 all_pass = False
+                 failures.append(f"{field}: got null/None value")
+                 metrics[f"{field}_actual"] = None
+                 metrics[f"{field}_expected"] = expected_value
+                 metrics[f"{field}_error"] = float('inf')
+                 metrics[f"{field}_pass"] = False
+                 continue
+
+             tolerance_config = tolerances.get(field, {"type": "absolute", "value": 0})
+             if isinstance(tolerances, dict) and "type" in tolerances and "value" not in tolerances.get(field, {}):
+                 tolerance_config = tolerances
+             tolerance_type = tolerance_config.get("type", "absolute")
+             has_asymmetric = "lower" in tolerance_config and "upper" in tolerance_config
+             tolerance_value = tolerance_config.get("value", 0)
+             tolerance_lower = tolerance_config.get("lower", tolerance_value)
+             tolerance_upper = tolerance_config.get("upper", tolerance_value)
+
+             try:
+                 if tolerance_type == "absolute":
+                     if has_asymmetric:
+                         within_tolerance = (expected_value - tolerance_lower) <= actual_value <= (expected_value + tolerance_upper)
+                         error = actual_value - expected_value
+                     else:
+                         within_tolerance = abs(actual_value - expected_value) <= tolerance_value
+                         error = abs(actual_value - expected_value)
+                 elif tolerance_type == "relative":
+                     relative_error = abs(actual_value - expected_value) / abs(expected_value) if expected_value != 0 else float('inf')
+                     within_tolerance = relative_error <= tolerance_value
+                     error = relative_error
+                 elif tolerance_type == "min":
+                     threshold = tolerance_value
+                     within_tolerance = actual_value >= threshold
+                     error = threshold - actual_value if actual_value < threshold else 0
+                 elif tolerance_type == "max":
+                     threshold = tolerance_value
+                     within_tolerance = actual_value <= threshold
+                     error = actual_value - threshold if actual_value > threshold else 0
+                 else:
+                     within_tolerance = False
+                     error = float('inf')
+             except TypeError:
+                 all_pass = False
+                 failures.append(f"{field}: invalid type {type(actual_value).__name__}, expected numeric")
+                 metrics[f"{field}_actual"] = actual_value
+                 metrics[f"{field}_expected"] = expected_value
+                 metrics[f"{field}_error"] = float('inf')
+                 metrics[f"{field}_pass"] = False
+                 continue
+
+             metrics[f"{field}_actual"] = actual_value
+             metrics[f"{field}_expected"] = expected_value
+             metrics[f"{field}_error"] = error
+             metrics[f"{field}_pass"] = within_tolerance
+
+             if not within_tolerance:
+                 all_pass = False
+                 if tolerance_type == "min":
+                     failures.append(f"{field}: {actual_value} (minimum required: {tolerance_value})")
+                 elif tolerance_type == "max":
+                     failures.append(f"{field}: {actual_value} (maximum allowed: {tolerance_value})")
+                 elif has_asymmetric:
+                     failures.append(f"{field}: {actual_value} vs {expected_value} (allowed: -{tolerance_lower}/+{tolerance_upper})")
+                 else:
+                     failures.append(f"{field}: {actual_value} vs {expected_value} (error: {error:.2f}, tolerance: {tolerance_value})")
+
+         reasoning = self._format_reasoning(ground_truth, tolerances, metrics, failures, all_pass)
+
+         return GraderResult(
+             passed=all_pass,
+             metrics=metrics,
+             reasoning=reasoning,
+             agent_answer=agent_answer
+         )
+
+     def _format_reasoning(self, ground_truth, tolerances, metrics, failures, passed):
+         lines = [f"Numeric Tolerance Check: {'PASS' if passed else 'FAIL'}", ""]
+
+         for field in ground_truth.keys():
+             if f"{field}_actual" in metrics:
+                 actual = metrics[f"{field}_actual"]
+                 expected = metrics[f"{field}_expected"]
+                 error = metrics[f"{field}_error"]
+                 field_pass = metrics[f"{field}_pass"]
+                 check = "+" if field_pass else "x"
+                 tolerance_config = tolerances.get(field, {}) if isinstance(tolerances, dict) else {}
+                 tolerance_type = tolerance_config.get("type", "absolute")
+                 has_asymmetric = "lower" in tolerance_config and "upper" in tolerance_config
+                 if tolerance_type == "min":
+                     tol_val = tolerance_config.get("value", expected)
+                     lines.append(f" {check} {field}: {actual} (minimum: {tol_val})")
+                 elif tolerance_type == "max":
+                     tol_val = tolerance_config.get("value", expected)
+                     lines.append(f" {check} {field}: {actual} (maximum: {tol_val})")
+                 elif has_asymmetric:
+                     lower = tolerance_config["lower"]
+                     upper = tolerance_config["upper"]
+                     lines.append(f" {check} {field}: {actual} vs {expected} (allowed: -{lower}/+{upper})")
+                 else:
+                     lines.append(f" {check} {field}: {actual} vs {expected} (error: {error:.4f})")
+
+         if not passed and failures:
+             lines.extend(["", "Failures:"])
+             for failure in failures:
+                 lines.append(f" - {failure}")
+
+         return "\n".join(lines)
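A minimal, hypothetical sketch for NumericToleranceGrader above, covering an absolute, a relative, and a max-style tolerance. The import path, no-argument constructor, field names, and the assumption that get_nested_value resolves plain top-level keys are all illustrative, not verified against the package.

# Hypothetical usage sketch; import path and no-arg constructor are assumed.
from latch_eval_tools.graders.numeric import NumericToleranceGrader

config = {
    "ground_truth": {"n_clusters": 12, "median_genes_per_cell": 2400, "mito_fraction": 0.08},
    "tolerances": {
        "n_clusters": {"type": "absolute", "value": 1},
        "median_genes_per_cell": {"type": "relative", "value": 0.10},
        "mito_fraction": {"type": "max", "value": 0.10},
    },
}
agent_answer = {"n_clusters": 13, "median_genes_per_cell": "2310", "mito_fraction": 0.09}

result = NumericToleranceGrader().evaluate_answer(agent_answer, config)
# n_clusters: |13 - 12| = 1 <= 1; median_genes_per_cell: the string is parsed to float and the
# relative error 90/2400 = 0.0375 <= 0.10; mito_fraction: 0.09 <= 0.10 -> all fields pass
print(result.passed)
print(result.reasoning)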