themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/evaluation/conditional.py
@@ -0,0 +1,410 @@
+"""Conditional and adaptive evaluation strategies.
+
+This module provides evaluation components that adapt based on sample characteristics:
+- ConditionalMetric: Only runs when condition is met
+- AdaptiveEvaluationPipeline: Selects metrics based on sample metadata
+- Metric selectors: Helper functions for common selection patterns
+
+Example:
+    >>> # Only run math verification on math problems
+    >>> math_metric = ConditionalMetric(
+    ...     metric=MathVerifyAccuracy(),
+    ...     condition=lambda record: record.task.metadata.get("type") == "math"
+    ... )
+    >>>
+    >>> # Adaptively select metrics based on task type
+    >>> def select_metrics(record):
+    ...     if record.task.metadata.get("type") == "math":
+    ...         return [ExactMatch(), MathVerifyAccuracy()]
+    ...     return [ExactMatch()]
+    >>>
+    >>> pipeline = AdaptiveEvaluationPipeline(
+    ...     extractor=extractor,
+    ...     metric_selector=select_metrics
+    ... )
+"""
+
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Callable, Sequence
+
+from themis.core import entities as core_entities
+from themis.evaluation import pipeline, reports
+from themis.interfaces import Metric
+from themis.utils import tracing
+
+
+@dataclass
+class ConditionalMetric:
+    """Metric that only runs when condition is met.
+
+    This wrapper allows you to conditionally apply metrics based on
+    record characteristics (metadata, task type, etc.).
+
+    Attributes:
+        metric: Wrapped metric
+        condition: Function that determines if metric should run
+        default_score: Score to return when condition is False
+        name: Optional override for metric name
+
+    Example:
+        >>> # Only run expensive metric on hard problems
+        >>> hard_metric = ConditionalMetric(
+        ...     metric=ExpensiveVerification(),
+        ...     condition=lambda r: r.task.metadata.get("difficulty") == "hard",
+        ...     default_score=0.0
+        ... )
+    """
+
+    metric: Metric
+    condition: Callable[[core_entities.GenerationRecord], bool]
+    default_score: float = 0.0
+    name: str | None = None
+
+    def __post_init__(self):
+        if self.name is None:
+            self.name = f"conditional_{self.metric.name}"
+
+    def should_evaluate(self, record: core_entities.GenerationRecord) -> bool:
+        """Check if metric should be evaluated for this record.
+
+        Args:
+            record: Generation record
+
+        Returns:
+            True if condition is met
+        """
+        try:
+            return self.condition(record)
+        except Exception:
+            # If condition check fails, don't run metric
+            return False
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        """Compute metric score.
+
+        Note: This method doesn't check the condition - it's assumed
+        the condition was already checked before calling compute.
+
+        Args:
+            prediction: Predicted value
+            references: Reference values
+            metadata: Optional metadata
+
+        Returns:
+            Metric score
+        """
+        return self.metric.compute(
+            prediction=prediction,
+            references=references,
+            metadata=metadata,
+        )
+
+    def compute_or_default(
+        self,
+        record: core_entities.GenerationRecord,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        """Compute metric or return default score if condition not met.
+
+        Args:
+            record: Generation record (for condition check)
+            prediction: Predicted value
+            references: Reference values
+            metadata: Optional metadata
+
+        Returns:
+            Metric score or default
+        """
+        if self.should_evaluate(record):
+            return self.compute(
+                prediction=prediction,
+                references=references,
+                metadata=metadata,
+            )
+        else:
+            return core_entities.MetricScore(
+                metric_name=self.name or self.metric.name,
+                value=self.default_score,
+                metadata={"skipped": True, "reason": "condition_not_met"},
+            )
+
+
+class AdaptiveEvaluationPipeline(pipeline.EvaluationPipeline):
+    """Pipeline that selects metrics based on sample characteristics.
+
+    This pipeline allows different metrics to be applied to different
+    samples based on their metadata, task type, or other characteristics.
+
+    This is more efficient than ConditionalMetric when you have many
+    samples that can be grouped by their metric requirements.
+
+    Example:
+        >>> def select_metrics(record):
+        ...     task_type = record.task.metadata.get("type")
+        ...     if task_type == "math":
+        ...         return [ExactMatch(), MathVerifyAccuracy()]
+        ...     elif task_type == "code":
+        ...         return [CodeExecutionMetric()]
+        ...     return [ExactMatch()]
+        >>>
+        >>> pipeline = AdaptiveEvaluationPipeline(
+        ...     extractor=extractor,
+        ...     metric_selector=select_metrics
+        ... )
+    """
+
+    def __init__(
+        self,
+        *,
+        extractor: Any,
+        metric_selector: Callable[[core_entities.GenerationRecord], list[Metric]],
+        **kwargs: Any,
+    ):
+        """Initialize adaptive pipeline.
+
+        Args:
+            extractor: Extractor for all samples
+            metric_selector: Function that selects metrics for each record
+            **kwargs: Additional arguments passed to EvaluationPipeline
+        """
+        # Initialize with empty metrics - we'll select them dynamically
+        super().__init__(extractor=extractor, metrics=[], **kwargs)
+        self._metric_selector = metric_selector
+
+    def evaluate(
+        self, records: Sequence[core_entities.GenerationRecord]
+    ) -> pipeline.EvaluationReport:
+        """Evaluate records with adaptive metric selection.
+
+        Args:
+            records: Generation records to evaluate
+
+        Returns:
+            Evaluation report
+        """
+        with tracing.span("adaptive_evaluation", num_records=len(records)):
+            # Group records by which metrics apply
+            metric_groups: dict[
+                tuple[str, ...], list[core_entities.GenerationRecord]
+            ] = defaultdict(list)
+            record_metrics: dict[str, list[Metric]] = {}
+
+            # Phase 1: Group records by metric set
+            with tracing.span("group_by_metrics"):
+                for record in records:
+                    selected_metrics = self._metric_selector(record)
+                    metric_key = tuple(m.name for m in selected_metrics)
+                    metric_groups[metric_key].append(record)
+
+                    # Store mapping for later
+                    sample_id = str(record.task.metadata.get("dataset_id", "unknown"))
+                    record_metrics[sample_id] = selected_metrics
+
+            # Phase 2: Evaluate each group with appropriate metrics
+            all_eval_records = []
+            with tracing.span("evaluate_groups", num_groups=len(metric_groups)):
+                for metric_key, group_records in metric_groups.items():
+                    if not group_records:
+                        continue
+
+                    # Get metrics for this group
+                    sample_id = str(
+                        group_records[0].task.metadata.get("dataset_id", "unknown")
+                    )
+                    group_metrics = record_metrics.get(sample_id, [])
+
+                    with tracing.span(
+                        "evaluate_group",
+                        metric_names=list(metric_key),
+                        num_records=len(group_records),
+                    ):
+                        # Create temporary pipeline for this group
+                        temp_pipeline = pipeline.EvaluationPipeline(
+                            extractor=self._extractor,
+                            metrics=group_metrics,
+                        )
+
+                        # Evaluate group
+                        group_report = temp_pipeline.evaluate(group_records)
+                        all_eval_records.extend(group_report.records)
+
+            # Phase 3: Aggregate all results
+            with tracing.span("aggregate_adaptive_results"):
+                # Collect all metric scores by metric name
+                metric_scores_by_name: dict[str, list[core_entities.MetricScore]] = (
+                    defaultdict(list)
+                )
+                for eval_record in all_eval_records:
+                    for score_record in eval_record.scores:
+                        metric_scores_by_name[score_record.metric_name].append(
+                            score_record
+                        )
+
+                # Compute aggregates
+                metric_aggregates = {}
+                for metric_name, score_objs in metric_scores_by_name.items():
+                    if score_objs:
+                        metric_aggregates[metric_name] = (
+                            reports.MetricAggregate.from_scores(
+                                name=metric_name,
+                                scores=score_objs,
+                            )
+                        )
+
+            return reports.EvaluationReport(
+                metrics=metric_aggregates,
+                failures=[],  # No failures tracked in adaptive pipeline
+                records=all_eval_records,
+            )
+
+
+# Helper functions for common metric selection patterns
+
+
+def select_by_metadata_field(
+    field: str, metric_map: dict[Any, list[Metric]], default: list[Metric] | None = None
+) -> Callable[[core_entities.GenerationRecord], list[Metric]]:
+    """Create selector that chooses metrics based on metadata field value.
+
+    Args:
+        field: Metadata field to check
+        metric_map: Mapping from field value to metrics
+        default: Default metrics if field value not in map
+
+    Returns:
+        Metric selector function
+
+    Example:
+        >>> selector = select_by_metadata_field(
+        ...     "type",
+        ...     {
+        ...         "math": [ExactMatch(), MathVerifyAccuracy()],
+        ...         "code": [CodeExecutionMetric()],
+        ...     },
+        ...     default=[ExactMatch()]
+        ... )
+    """
+    default_metrics = default or []
+
+    def selector(record: core_entities.GenerationRecord) -> list[Metric]:
+        value = record.task.metadata.get(field)
+        return metric_map.get(value, default_metrics)
+
+    return selector
+
+
+def select_by_difficulty(
+    easy_metrics: list[Metric],
+    medium_metrics: list[Metric],
+    hard_metrics: list[Metric],
+    difficulty_field: str = "difficulty",
+) -> Callable[[core_entities.GenerationRecord], list[Metric]]:
+    """Create selector that chooses metrics based on difficulty.
+
+    Args:
+        easy_metrics: Metrics for easy problems
+        medium_metrics: Metrics for medium problems
+        hard_metrics: Metrics for hard problems
+        difficulty_field: Name of difficulty field in metadata
+
+    Returns:
+        Metric selector function
+
+    Example:
+        >>> selector = select_by_difficulty(
+        ...     easy_metrics=[ExactMatch()],
+        ...     medium_metrics=[ExactMatch(), PartialCredit()],
+        ...     hard_metrics=[ExactMatch(), PartialCredit(), ManualReview()]
+        ... )
+    """
+    return select_by_metadata_field(
+        difficulty_field,
+        {
+            "easy": easy_metrics,
+            "medium": medium_metrics,
+            "hard": hard_metrics,
+        },
+        default=medium_metrics,
+    )


+def select_by_condition(
+    condition: Callable[[core_entities.GenerationRecord], bool],
+    metrics_if_true: list[Metric],
+    metrics_if_false: list[Metric],
+) -> Callable[[core_entities.GenerationRecord], list[Metric]]:
+    """Create selector based on arbitrary condition.
+
+    Args:
+        condition: Function to determine which metrics to use
+        metrics_if_true: Metrics if condition is True
+        metrics_if_false: Metrics if condition is False
+
+    Returns:
+        Metric selector function
+
+    Example:
+        >>> selector = select_by_condition(
+        ...     lambda r: len(r.output.text) > 1000,
+        ...     metrics_if_true=[SummaryMetrics()],
+        ...     metrics_if_false=[ExactMatch()]
+        ... )
+    """
+
+    def selector(record: core_entities.GenerationRecord) -> list[Metric]:
+        try:
+            if condition(record):
+                return metrics_if_true
+            else:
+                return metrics_if_false
+        except Exception:
+            # If condition fails, use false branch
+            return metrics_if_false
+
+    return selector
+
+
+def combine_selectors(
+    *selectors: Callable[[core_entities.GenerationRecord], list[Metric]],
+) -> Callable[[core_entities.GenerationRecord], list[Metric]]:
+    """Combine multiple selectors (union of their metrics).
+
+    Args:
+        *selectors: Metric selectors to combine
+
+    Returns:
+        Combined selector that returns union of all selected metrics
+
+    Example:
+        >>> selector = combine_selectors(
+        ...     select_by_type,
+        ...     select_by_difficulty,
+        ... )
+    """
+
+    def combined(record: core_entities.GenerationRecord) -> list[Metric]:
+        all_metrics = []
+        seen_names = set()
+
+        for selector in selectors:
+            selected = selector(record)
+            for metric in selected:
+                if metric.name not in seen_names:
+                    all_metrics.append(metric)
+                    seen_names.add(metric.name)
+
+        return all_metrics
+
+    return combined
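The hunk above carries its own doctest snippets, but `compute_or_default` (the path most callers would hit) is not shown end to end. Below is a minimal sketch of that flow using a stub metric and a `SimpleNamespace` standing in for a `GenerationRecord`, since only `task.metadata` is touched by the condition. The `MetricScore` keyword arguments mirror the ones used in the code above; the stub metric and the assumption that `MetricScore` exposes a `.value` attribute are illustrative, not part of the diff.

```python
from types import SimpleNamespace

from themis.core import entities as core_entities
from themis.evaluation.conditional import ConditionalMetric


class StubExactMatch:
    """Toy metric for illustration: 1.0 when the prediction equals any reference."""

    name = "stub_exact_match"

    def compute(self, *, prediction, references, metadata=None):
        value = 1.0 if prediction in references else 0.0
        return core_entities.MetricScore(
            metric_name=self.name, value=value, metadata=metadata or {}
        )


# Only score records whose task metadata marks them as math problems.
math_only = ConditionalMetric(
    metric=StubExactMatch(),
    condition=lambda r: r.task.metadata.get("type") == "math",
    default_score=0.0,
)

record = SimpleNamespace(task=SimpleNamespace(metadata={"type": "math"}))
score = math_only.compute_or_default(record, prediction="42", references=["42"])
print(score.value)  # 1.0; a non-math record instead gets default_score with
                    # metadata {"skipped": True, "reason": "condition_not_met"}
```

When many records share the same metric requirements, the docstring above suggests grouping them through `AdaptiveEvaluationPipeline` with one of the selector helpers rather than wrapping each metric individually.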
themis/evaluation/extractors/__init__.py
@@ -0,0 +1,19 @@
+"""Output extractors used during evaluation."""
+
+from __future__ import annotations
+
+from .error_taxonomy_extractor import ErrorTaxonomyExtractor
+from .exceptions import FieldExtractionError
+from .identity_extractor import IdentityExtractor
+from .json_field_extractor import JsonFieldExtractor
+from .math_verify_extractor import MathVerifyExtractor
+from .regex_extractor import RegexExtractor
+
+__all__ = [
+    "FieldExtractionError",
+    "JsonFieldExtractor",
+    "RegexExtractor",
+    "IdentityExtractor",
+    "MathVerifyExtractor",
+    "ErrorTaxonomyExtractor",
+]
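For reference, the re-exports above make every extractor importable from the package root; a minimal sketch, assuming the 0.2.0 wheel is installed:

```python
from themis.evaluation.extractors import (
    ErrorTaxonomyExtractor,
    FieldExtractionError,
    IdentityExtractor,
    JsonFieldExtractor,
    MathVerifyExtractor,
    RegexExtractor,
)

# Each extractor exposes a single extract(...) method; the ones that can fail
# (JSON, regex, math-verify) raise FieldExtractionError.
print(IdentityExtractor().extract("  42  "))  # "42"
```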
themis/evaluation/extractors/error_taxonomy_extractor.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import re
+from typing import Dict
+
+
+class ErrorTaxonomyExtractor:
+    """
+    Lightweight error taxonomy extractor.
+
+    Heuristics:
+    - format_parse_failure: Unbalanced JSON-like braces suggest parsing intent but malformed format
+    - arithmetic_slip: Simple arithmetic expression like "X + Y = Z" evaluated incorrectly
+    - reasoning_gap: Final answer given without common justification keywords
+    """
+
+    def extract(self, text: str) -> Dict[str, bool]:
+        labels = {
+            "format_parse_failure": False,
+            "arithmetic_slip": False,
+            "reasoning_gap": False,
+        }
+
+        # format_parse_failure: JSON-like but malformed braces
+        if ("{" in text or "}" in text) and not self._balanced_braces(text):
+            labels["format_parse_failure"] = True
+
+        # arithmetic_slip: pattern "A op B = Z" mismatch
+        try:
+            if self._has_arithmetic_mismatch(text):
+                labels["arithmetic_slip"] = True
+        except Exception:
+            pass
+
+        # reasoning_gap: answer provided without justification keywords
+        lowered = text.lower()
+        has_answer_phrase = any(p in lowered for p in ("answer", "final", "therefore"))
+        has_justification = any(
+            k in lowered for k in ("because", "since", "thus", "therefore", "reason")
+        )
+        if has_answer_phrase and not has_justification:
+            labels["reasoning_gap"] = True
+
+        return labels
+
+    def _balanced_braces(self, text: str) -> bool:
+        count = 0
+        for ch in text:
+            if ch == "{":
+                count += 1
+            elif ch == "}":
+                count -= 1
+                if count < 0:
+                    return False
+        return count == 0
+
+    def _has_arithmetic_mismatch(self, text: str) -> bool:
+        pattern = (
+            r"(-?\d+(?:\.\d+)?)\s*([+\-*/])\s*(-?\d+(?:\.\d+)?)\s*=\s*(-?\d+(?:\.\d+)?)"
+        )
+        m = re.search(pattern, text)
+        if not m:
+            return False
+        a, op, b, z = m.groups()
+        a = float(a)
+        b = float(b)
+        z = float(z)
+        if op == "+":
+            calc = a + b
+        elif op == "-":
+            calc = a - b
+        elif op == "*":
+            calc = a * b
+        elif op == "/":
+            if b == 0:
+                return False
+            calc = a / b
+        else:
+            return False
+        return abs(calc - z) > 1e-9
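A short sketch of what the heuristics above label on synthetic outputs; the strings are made up, and the expected dictionaries follow directly from the code in this hunk:

```python
from themis.evaluation.extractors import ErrorTaxonomyExtractor

extractor = ErrorTaxonomyExtractor()

# "2 + 2 = 5" matches the "A op B = Z" pattern but evaluates incorrectly;
# "because" counts as justification, so no reasoning_gap.
print(extractor.extract("The answer is 5 because 2 + 2 = 5"))
# {'format_parse_failure': False, 'arithmetic_slip': True, 'reasoning_gap': False}

# Unbalanced brace plus an answer phrase with no justification keyword.
print(extractor.extract('{"answer": 7'))
# {'format_parse_failure': True, 'arithmetic_slip': False, 'reasoning_gap': True}
```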
themis/evaluation/extractors/identity_extractor.py
@@ -0,0 +1,29 @@
+"""Identity (pass-through) extraction."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+
+@dataclass
+class IdentityExtractor:
+    """Extractor that returns the raw output as-is.
+
+    Args:
+        strip_whitespace: Whether to strip leading/trailing whitespace
+    """
+
+    strip_whitespace: bool = True
+
+    def extract(self, raw_output: str) -> str:
+        """Return the raw output, optionally stripping whitespace.
+
+        Args:
+            raw_output: Raw output from model
+
+        Returns:
+            Raw output (possibly stripped)
+        """
+        if self.strip_whitespace:
+            return raw_output.strip()
+        return raw_output
themis/evaluation/extractors/json_field_extractor.py
@@ -0,0 +1,45 @@
+"""JSON field extraction."""
+
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from typing import Any
+
+from .exceptions import FieldExtractionError
+
+
+@dataclass
+class JsonFieldExtractor:
+    """Extracts a specific field from JSON output.
+
+    Args:
+        field_path: Dot-separated path to the field (e.g., "answer" or "result.value")
+    """
+
+    field_path: str
+
+    def extract(self, raw_output: str) -> Any:
+        """Extract the specified field from JSON output.
+
+        Args:
+            raw_output: Raw JSON string from model
+
+        Returns:
+            Extracted field value
+
+        Raises:
+            FieldExtractionError: If JSON is invalid or field is missing
+        """
+        try:
+            payload = json.loads(raw_output)
+        except json.JSONDecodeError as exc:  # pragma: no cover - defensive path
+            raise FieldExtractionError("Invalid JSON output") from exc
+
+        current = payload
+        for part in self.field_path.split("."):
+            if isinstance(current, dict) and part in current:
+                current = current[part]
+            else:
+                raise FieldExtractionError(f"Missing field '{self.field_path}'")
+        return current
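A minimal sketch of the dot-path traversal above (the payloads are made up):

```python
from themis.evaluation.extractors import FieldExtractionError, JsonFieldExtractor

extractor = JsonFieldExtractor(field_path="result.value")

print(extractor.extract('{"result": {"value": 42, "unit": "s"}}'))  # 42

try:
    extractor.extract('{"result": {}}')
except FieldExtractionError as exc:
    print(exc)  # Missing field 'result.value'
```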
themis/evaluation/extractors/math_verify_extractor.py
@@ -0,0 +1,37 @@
+"""Math-verify extraction for mathematical expressions."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from .exceptions import FieldExtractionError
+
+
+@dataclass
+class MathVerifyExtractor:
+    """Extracts the final boxed answer using math-verify parsing.
+
+    This extractor uses the math-verify library to parse and normalize
+    mathematical expressions from LaTeX boxed notation.
+    """
+
+    def extract(self, raw_output: str) -> str:
+        """Extract and parse boxed mathematical answer.
+
+        Args:
+            raw_output: Raw output containing \\boxed{...} notation
+
+        Returns:
+            Parsed and normalized mathematical expression
+
+        Raises:
+            FieldExtractionError: If math-verify parsing fails
+        """
+        from themis.evaluation import math_verify_utils as mv_utils
+
+        candidate = mv_utils.extract_last_boxed(raw_output)
+        try:
+            parsed = mv_utils.parse_expression(candidate)
+        except Exception as exc:  # pragma: no cover - parse failure
+            raise FieldExtractionError("math-verify parsing failed") from exc
+        return str(parsed).strip()
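Usage is a one-liner, though the exact normalized string depends on `themis.evaluation.math_verify_utils` and the third-party math-verify package, neither of which appears in this hunk; treat the output comment in this sketch as indicative only:

```python
from themis.evaluation.extractors import FieldExtractionError, MathVerifyExtractor

extractor = MathVerifyExtractor()

try:
    # The last \boxed{...} in the completion is pulled out and parsed.
    answer = extractor.extract(r"Combining terms gives $\boxed{\frac{1}{2}}$.")
    print(answer)  # some normalized form of 1/2, depending on math-verify
except FieldExtractionError:
    # Raised when math-verify cannot parse the boxed candidate.
    pass
```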
themis/evaluation/extractors/regex_extractor.py
@@ -0,0 +1,43 @@
+"""Regex-based extraction."""
+
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass
+from typing import Dict
+
+from .exceptions import FieldExtractionError
+
+
+@dataclass
+class RegexExtractor:
+    """Extracts fields using regular expression patterns.
+
+    Args:
+        pattern: Regular expression pattern with optional named groups
+    """
+
+    pattern: str
+
+    def __post_init__(self) -> None:
+        self._compiled = re.compile(self.pattern)
+
+    def extract(self, text: str) -> Dict[str, str]:
+        """Extract fields from text using regex pattern.
+
+        Args:
+            text: Text to extract from
+
+        Returns:
+            Dictionary of extracted groups (named or numbered)
+
+        Raises:
+            FieldExtractionError: If pattern does not match
+        """
+        match = self._compiled.search(text)
+        if not match:
+            raise FieldExtractionError("Regex did not match")
+        groups = match.groupdict()
+        if groups:
+            return {key: value.strip() for key, value in groups.items()}
+        return {str(index): value.strip() for index, value in enumerate(match.groups())}
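A short sketch of named-group versus positional-group extraction (patterns and text are illustrative):

```python
from themis.evaluation.extractors import FieldExtractionError, RegexExtractor

# Named groups come back keyed by group name, with values stripped.
named = RegexExtractor(pattern=r"Answer:\s*(?P<answer>\S+)")
print(named.extract("Answer: 42"))  # {'answer': '42'}

# Without named groups, captures are keyed by their position.
positional = RegexExtractor(pattern=r"(\d+)\s*/\s*(\d+)")
print(positional.extract("Score: 7 / 10"))  # {'0': '7', '1': '10'}

try:
    named.extract("nothing to see here")
except FieldExtractionError as exc:
    print(exc)  # Regex did not match
```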