themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/evaluation/metrics/code/pass_at_k.py
@@ -0,0 +1,181 @@
+"""Pass@k metric for code generation evaluation.
+
+Pass@k measures functional correctness by executing k generated code samples
+and checking if any of them pass the test cases.
+
+References:
+    Chen et al. (2021). Evaluating Large Language Models Trained on Code.
+    (HumanEval paper)
+"""
+
+from __future__ import annotations
+
+import math
+from typing import Any, Sequence
+
+from themis.core.entities import MetricScore
+from themis.interfaces import Metric
+
+
+def estimate_pass_at_k(n: int, c: int, k: int) -> float:
+    """Estimate pass@k using the unbiased estimator.
+
+    This is the standard estimator from the HumanEval paper.
+
+    Args:
+        n: Total number of samples generated
+        c: Number of samples that passed
+        k: k value for pass@k
+
+    Returns:
+        Estimated pass@k probability
+
+    Example:
+        >>> # Generated 10 samples, 3 passed, compute pass@1
+        >>> estimate_pass_at_k(n=10, c=3, k=1)
+        0.3
+
+        >>> # Generated 100 samples, 30 passed, compute pass@10
+        >>> estimate_pass_at_k(n=100, c=30, k=10)
+        0.9771
+    """
+    if n - c < k:
+        return 1.0
+
+    # Unbiased estimator: 1 - C(n-c, k) / C(n, k)
+    # = 1 - product((n-c-i)/(n-i) for i in range(k))
+    result = 1.0
+    for i in range(k):
+        result *= (n - c - i) / (n - i)
+
+    return 1.0 - result
+
+
+class PassAtK(Metric):
+    """Pass@k metric for code generation.
+
+    Pass@k measures the probability that at least one of k generated samples
+    passes all test cases. It's the standard metric for evaluating code
+    generation models like Codex, CodeGen, etc.
+
+    The metric requires:
+    - Multiple samples per problem (num_samples >= k)
+    - Test cases to execute against
+    - Safe code execution environment
+
+    Attributes:
+        name: Metric identifier ("pass_at_k")
+        k: Number of samples to consider
+        timeout: Maximum execution time per sample (seconds)
+        require_all_tests: Whether all tests must pass (vs any test)
+
+    Example:
+        >>> from themis.evaluation.metrics.code import PassAtK
+        >>> metric = PassAtK(k=1)
+        >>> score = metric.compute(
+        ...     prediction={
+        ...         "samples": ["def add(a, b): return a + b", ...],
+        ...         "test_results": [True, False, ...],
+        ...     },
+        ...     references=[]
+        ... )
+        >>> print(f"Pass@1: {score.value:.2%}")
+        Pass@1: 30.00%
+    """
+
+    requires_reference = False  # Uses test execution, not reference matching
+
+    def __init__(
+        self,
+        k: int = 1,
+        timeout: float = 3.0,
+        require_all_tests: bool = True,
+    ):
+        """Initialize Pass@k metric.
+
+        Args:
+            k: Number of samples for pass@k estimation
+            timeout: Maximum execution time per sample (seconds)
+            require_all_tests: Whether all test cases must pass (default: True)
+        """
+        self.name = f"pass_at_{k}"
+        self.k = k
+        self.timeout = timeout
+        self.require_all_tests = require_all_tests
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> MetricScore:
+        """Compute Pass@k score.
+
+        Args:
+            prediction: Dictionary containing:
+                - "samples": List of generated code samples
+                - "test_results": List of booleans (True if passed)
+                - "execution_errors": Optional list of error messages
+            references: Not used (test-based evaluation)
+            metadata: Optional metadata dict
+
+        Returns:
+            MetricScore with estimated pass@k probability
+
+        Note:
+            The prediction should be prepared by the ExecutionAccuracy metric
+            or a similar execution framework.
+        """
+        if not isinstance(prediction, dict):
+            return MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={"error": "Prediction must be dict with samples and test_results"},
+                metadata=metadata or {},
+            )
+
+        samples = prediction.get("samples", [])
+        test_results = prediction.get("test_results", [])
+
+        if not samples or not test_results:
+            return MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={
+                    "error": "Missing samples or test_results",
+                    "num_samples": len(samples),
+                    "num_results": len(test_results),
+                },
+                metadata=metadata or {},
+            )
+
+        # Count number of samples and passes
+        n = len(test_results)
+        c = sum(1 for result in test_results if result)
+
+        # Estimate pass@k
+        if n < self.k:
+            # Not enough samples, use empirical rate
+            pass_at_k = c / n if n > 0 else 0.0
+            warning = f"Only {n} samples available for pass@{self.k}"
+        else:
+            pass_at_k = estimate_pass_at_k(n, c, self.k)
+            warning = None
+
+        return MetricScore(
+            metric_name=self.name,
+            value=pass_at_k,
+            details={
+                "k": self.k,
+                "n_samples": n,
+                "n_passed": c,
+                "pass_rate": c / n if n > 0 else 0.0,
+                "pass_at_k": pass_at_k,
+                "warning": warning,
+            },
+            metadata=metadata or {},
+        )
+
+
+__all__ = ["PassAtK", "estimate_pass_at_k"]
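
For reference, the same estimate can be checked against the closed form 1 - C(n-c, k) / C(n, k) using nothing but the standard library. The helper name pass_at_k_closed_form below is hypothetical, written only for this check, and it reproduces the docstring's two examples:

import math

def pass_at_k_closed_form(n: int, c: int, k: int) -> float:
    # Same quantity as estimate_pass_at_k, written with binomial coefficients:
    # pass@k = 1 - C(n - c, k) / C(n, k)
    if n - c < k:
        return 1.0
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)

print(round(pass_at_k_closed_form(10, 3, 1), 4))     # 0.3
print(round(pass_at_k_closed_form(100, 30, 10), 4))  # 0.9771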
themis/evaluation/metrics/composite_metric.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+from themis.core import entities as core_entities
+from themis.interfaces import Metric as MetricInterface
+
+
+@dataclass
+class CompositeMetric(MetricInterface):
+    children: Sequence[MetricInterface]
+
+    def __post_init__(self) -> None:
+        self.name = "CompositeMetric"
+        self.requires_reference = any(
+            getattr(child, "requires_reference", True) for child in self.children
+        )
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        child_results = [
+            child.compute(
+                prediction=prediction, references=references, metadata=metadata
+            )
+            for child in self.children
+        ]
+        if not child_results:
+            return core_entities.MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={},
+                metadata=dict(metadata or {}),
+            )
+        value = sum(result.value for result in child_results) / len(child_results)
+        details = {result.metric_name: result.details for result in child_results}
+        return core_entities.MetricScore(
+            metric_name=self.name,
+            value=value,
+            details=details,
+            metadata=dict(metadata or {}),
+        )
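
A minimal usage sketch of CompositeMetric: the score is the unweighted mean of the children's scores and details is keyed by each child's metric name. The imports assume the classes are loaded directly from the module files added in this diff (ExactMatch and LengthDifferenceTolerance appear a few hunks below); the 0.2.0 API may also re-export them elsewhere:

from themis.evaluation.metrics.composite_metric import CompositeMetric
from themis.evaluation.metrics.exact_match import ExactMatch
from themis.evaluation.metrics.length_difference_tolerance import LengthDifferenceTolerance

# One strict child and one lenient child, averaged into a single score.
metric = CompositeMetric(children=[ExactMatch(), LengthDifferenceTolerance(max_delta=2)])
score = metric.compute(prediction="Paris", references=["paris"])

print(score.value)            # 1.0 -> mean of the two child scores (1.0 and 1.0)
print(sorted(score.details))  # ['ExactMatch', 'LengthDifferenceTolerance']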
themis/evaluation/metrics/consistency_metric.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+from themis.core import entities as core_entities
+from themis.interfaces import Metric as MetricInterface
+
+
+def _normalize_text(value: str, case_sensitive: bool, strip_whitespace: bool) -> str:
+    if strip_whitespace:
+        value = value.strip()
+    if not case_sensitive:
+        value = value.lower()
+    return value
+
+
+@dataclass
+class ConsistencyMetric(MetricInterface):
+    case_sensitive: bool = False
+    strip_whitespace: bool = True
+
+    def __post_init__(self) -> None:
+        self.name = "Consistency"
+        self.requires_reference = False
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        md = dict(metadata or {})
+
+        outputs: list[str]
+        if isinstance(prediction, (list, tuple)):
+            outputs = [str(p) for p in prediction]
+        else:
+            outputs = [str(prediction)]
+
+        normalized = [
+            _normalize_text(text, self.case_sensitive, self.strip_whitespace)
+            for text in outputs
+        ]
+
+        majority_correct = None
+        reference_text = None
+        if references:
+            reference_text = _normalize_text(
+                str(references[0]), self.case_sensitive, self.strip_whitespace
+            )
+            correct = [1.0 if out == reference_text else 0.0 for out in normalized]
+            majority_correct = sum(correct) / max(1, len(correct))
+
+        from collections import Counter
+
+        counter = Counter(normalized)
+        mode_count = max(counter.values()) if counter else 0
+        agreement = mode_count / max(1, len(normalized))
+
+        flips = 0
+        for i in range(1, len(normalized)):
+            if normalized[i] != normalized[i - 1]:
+                flips += 1
+        flip_rate = flips / max(1, len(normalized) - 1)
+
+        value = majority_correct if majority_correct is not None else agreement
+
+        return core_entities.MetricScore(
+            metric_name=self.name,
+            value=float(value),
+            details={
+                "agreement": agreement,
+                "flip_rate": flip_rate,
+                "outputs": outputs,
+                "reference": reference_text,
+            },
+            metadata=md,
+        )
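
A minimal sketch of ConsistencyMetric on repeated samples for the same prompt (imported from the module file above): when a reference is given, the score is the fraction of normalized outputs that match it, while details records self-agreement and the order-sensitive flip rate:

from themis.evaluation.metrics.consistency_metric import ConsistencyMetric

metric = ConsistencyMetric()  # case-insensitive, whitespace-stripped by default
score = metric.compute(
    prediction=["42", "42 ", "forty-two", "42"],  # four samples of one prompt
    references=["42"],
)

print(score.value)                 # 0.75 -> 3 of 4 normalized outputs match the reference
print(score.details["agreement"])  # 0.75 -> the modal answer covers 3 of 4 outputs
print(score.details["flip_rate"])  # 0.666... -> 2 of 3 adjacent pairs disagree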
themis/evaluation/metrics/exact_match.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+from themis.core import entities as core_entities
+from themis.interfaces import Metric as MetricInterface
+
+
+def _normalize_text(value: str, case_sensitive: bool, strip_whitespace: bool) -> str:
+    if strip_whitespace:
+        value = value.strip()
+    if not case_sensitive:
+        value = value.lower()
+    return value
+
+
+@dataclass
+class ExactMatch(MetricInterface):
+    case_sensitive: bool = False
+    strip_whitespace: bool = True
+
+    def __post_init__(self) -> None:
+        self.name = "ExactMatch"
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        metadata = dict(metadata or {})
+        normalized_prediction = _normalize_text(
+            str(prediction), self.case_sensitive, self.strip_whitespace
+        )
+        matched_reference: str | None = None
+        for reference in references:
+            normalized_reference = _normalize_text(
+                str(reference), self.case_sensitive, self.strip_whitespace
+            )
+            if normalized_prediction == normalized_reference:
+                matched_reference = str(reference)
+                break
+        value = 1.0 if matched_reference is not None else 0.0
+        return core_entities.MetricScore(
+            metric_name=self.name,
+            value=value,
+            details={"matched_reference": matched_reference},
+            metadata=metadata,
+        )
themis/evaluation/metrics/length_difference_tolerance.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+from themis.core import entities as core_entities
+from themis.interfaces import Metric as MetricInterface
+
+
+@dataclass
+class LengthDifferenceTolerance(MetricInterface):
+    max_delta: int = 0
+
+    def __post_init__(self) -> None:
+        self.name = "LengthDifferenceTolerance"
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        metadata = dict(metadata or {})
+        reference = str(references[0]) if references else ""
+        delta = abs(len(str(prediction)) - len(reference))
+        value = 1.0 if delta <= self.max_delta else 0.0
+        return core_entities.MetricScore(
+            metric_name=self.name,
+            value=value,
+            details={"delta": delta},
+            metadata=metadata,
+        )
themis/evaluation/metrics/math_verify_accuracy.py
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+from themis.core import entities as core_entities
+from themis.evaluation import math_verify_utils
+from themis.interfaces import Metric as MetricInterface
+
+
+@dataclass
+class MathVerifyAccuracy(MetricInterface):
+    """Numeric equivalence check using math-verify."""
+
+    def __post_init__(self) -> None:
+        math_verify_utils.require_math_verify()
+        self.name = "MathVerifyAccuracy"
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        math_verify_utils.require_math_verify()
+        metadata = dict(metadata or {})
+        prediction_expr = math_verify_utils.parse_expression(str(prediction))
+        passed = False
+        for reference in references:
+            reference_expr = math_verify_utils.parse_expression(str(reference))
+            if math_verify_utils.verify_expressions(reference_expr, prediction_expr):
+                passed = True
+                break
+        return core_entities.MetricScore(
+            metric_name=self.name,
+            value=1.0 if passed else 0.0,
+            details={"verified": passed},
+            metadata=metadata,
+        )
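
A minimal sketch of MathVerifyAccuracy. It assumes the optional math-verify dependency is installed (require_math_verify() is expected to fail fast otherwise) and that math-verify judges the two answer forms below as equivalent; the actual parsing lives in themis/evaluation/math_verify_utils.py, which is not shown in this excerpt:

from themis.evaluation.metrics.math_verify_accuracy import MathVerifyAccuracy

metric = MathVerifyAccuracy()  # construction already checks that math-verify is available
score = metric.compute(prediction="1/2", references=["0.5"])

print(score.value)                # expected 1.0 if math-verify treats 1/2 and 0.5 as equal
print(score.details["verified"])  # True when any reference verifies against the prediction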
themis/evaluation/metrics/nlp/__init__.py
@@ -0,0 +1,21 @@
+"""NLP evaluation metrics.
+
+This module provides standard NLP metrics for text generation evaluation:
+- BLEU: Bilingual Evaluation Understudy for translation quality
+- ROUGE: Recall-Oriented Understudy for Gisting Evaluation for summarization
+- BERTScore: Contextual embeddings-based evaluation
+- METEOR: Metric for Evaluation of Translation with Explicit ORdering
+"""
+
+from themis.evaluation.metrics.nlp.bleu import BLEU
+from themis.evaluation.metrics.nlp.rouge import ROUGE, ROUGEVariant
+from themis.evaluation.metrics.nlp.bertscore import BERTScore
+from themis.evaluation.metrics.nlp.meteor import METEOR
+
+__all__ = [
+    "BLEU",
+    "ROUGE",
+    "ROUGEVariant",
+    "BERTScore",
+    "METEOR",
+]
themis/evaluation/metrics/nlp/bertscore.py
@@ -0,0 +1,138 @@
+"""BERTScore metric implementation.
+
+BERTScore computes similarity using contextual embeddings from BERT-like models
+instead of exact word matches.
+
+References:
+    Zhang et al. (2020). BERTScore: Evaluating Text Generation with BERT.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Sequence
+
+from themis.core.entities import MetricScore
+from themis.interfaces import Metric
+
+
+class BERTScore(Metric):
+    """BERTScore metric using the bert-score library.
+
+    BERTScore leverages contextual embeddings from pre-trained models (BERT, RoBERTa, etc.)
+    to compute semantic similarity between generated and reference texts. It's more
+    robust to paraphrasing than exact n-gram matching methods.
+
+    The metric computes token-level cosine similarity between embeddings and aggregates
+    using precision, recall, and F1.
+
+    Attributes:
+        name: Metric identifier ("bertscore")
+        model_type: Pre-trained model to use for embeddings
+        lang: Language code for automatic model selection
+        rescale_with_baseline: Whether to rescale scores using baseline
+
+    Example:
+        >>> from themis.evaluation.metrics.nlp import BERTScore
+        >>> metric = BERTScore(model_type="microsoft/deberta-xlarge-mnli")
+        >>> score = metric.compute(
+        ...     prediction="The cat sat on the mat",
+        ...     references=["A cat is sitting on a mat"]
+        ... )
+        >>> print(f"BERTScore F1: {score.value:.4f}")
+        BERTScore F1: 0.9234
+    """
+
+    requires_reference = True
+
+    def __init__(
+        self,
+        model_type: str | None = None,
+        lang: str | None = None,
+        rescale_with_baseline: bool = True,
+        device: str | None = None,
+    ):
+        """Initialize BERTScore metric.
+
+        Args:
+            model_type: Pre-trained model identifier. Popular choices:
+                - "microsoft/deberta-xlarge-mnli" (recommended, large)
+                - "microsoft/deberta-large-mnli" (good balance)
+                - "roberta-large" (fast, good quality)
+                - "bert-base-uncased" (fastest, lower quality)
+            lang: Language code (e.g., "en", "zh", "fr"). If provided,
+                automatically selects an appropriate model.
+            rescale_with_baseline: Whether to rescale scores using baseline
+                (recommended for human correlation)
+            device: Device to use ("cuda", "cpu", or None for auto-detect)
+        """
+        self.name = "bertscore"
+        self.model_type = model_type
+        self.lang = lang
+        self.rescale_with_baseline = rescale_with_baseline
+        self.device = device
+
+        # Lazy import bert-score (not required for all users)
+        try:
+            import bert_score
+            self._bert_score = bert_score
+        except ImportError:
+            raise ImportError(
+                "bert-score is required for BERTScore metric. "
+                "Install it with: pip install bert-score"
+            )
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> MetricScore:
+        """Compute BERTScore.
+
+        Args:
+            prediction: Generated text (already extracted by pipeline)
+            references: List of reference texts
+            metadata: Optional metadata dict
+
+        Returns:
+            MetricScore with BERTScore F1 and precision/recall details
+        """
+        # Convert to strings
+        pred_str = str(prediction)
+        ref_strs = [str(ref) for ref in references]
+
+        # Compute BERTScore
+        # Note: bert_score.score expects lists of predictions and references
+        P, R, F1 = self._bert_score.score(
+            [pred_str] * len(ref_strs),  # Repeat prediction for each reference
+            ref_strs,
+            model_type=self.model_type,
+            lang=self.lang,
+            rescale_with_baseline=self.rescale_with_baseline,
+            device=self.device,
+            verbose=False,
+        )
+
+        # Take maximum F1 across references
+        max_idx = F1.argmax().item()
+        max_precision = P[max_idx].item()
+        max_recall = R[max_idx].item()
+        max_f1 = F1[max_idx].item()
+
+        return MetricScore(
+            metric_name=self.name,
+            value=max_f1,  # Use F1 as primary score
+            details={
+                "precision": max_precision,
+                "recall": max_recall,
+                "f1": max_f1,
+                "model_type": self.model_type or f"auto-{self.lang}",
+                "num_references": len(ref_strs),
+                "rescaled": self.rescale_with_baseline,
+            },
+            metadata=metadata or {},
+        )
+
+
+__all__ = ["BERTScore"]
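
Following on from the docstring example, a minimal multi-reference sketch. It assumes the optional bert-score dependency (and its torch backend) is installed and that a rescaling baseline exists for the chosen language; compute() scores the prediction against every reference and keeps the best-F1 pair, so the matching precision and recall are available in details:

from themis.evaluation.metrics.nlp import BERTScore

metric = BERTScore(lang="en")  # lang selects a default model; pass model_type to override
score = metric.compute(
    prediction="The cat sat on the mat",
    references=["A cat is sitting on a mat", "The weather is nice today"],
)

print(round(score.value, 4))            # best F1 across the two references
print(score.details["num_references"])  # 2
print(score.details["model_type"])      # "auto-en" when only lang is given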