themis-eval 0.1.1-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/main.py +427 -57
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/core/entities.py +23 -3
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/pipelines/standard_pipeline.py +68 -8
- themis/experiment/cache_manager.py +8 -3
- themis/experiment/export.py +110 -2
- themis/experiment/orchestrator.py +48 -6
- themis/experiment/storage.py +1313 -110
- themis/integrations/huggingface.py +12 -1
- themis/integrations/wandb.py +13 -1
- themis/interfaces/__init__.py +86 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.0.dist-info}/RECORD +40 -17
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.1.dist-info/METADATA +0 -758
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/evaluation/metrics/code/execution.py

@@ -0,0 +1,280 @@
+"""Safe code execution for testing functional correctness.
+
+This module provides utilities for safely executing generated code against
+test cases in a sandboxed environment.
+"""
+
+from __future__ import annotations
+
+import multiprocessing
+import signal
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Sequence
+
+from themis.core.entities import MetricScore
+from themis.interfaces import Metric
+
+
+class ExecutionStatus(str, Enum):
+    """Execution result status."""
+
+    PASSED = "passed"
+    FAILED = "failed"
+    TIMEOUT = "timeout"
+    ERROR = "error"
+
+
+@dataclass
+class ExecutionResult:
+    """Result of code execution.
+
+    Attributes:
+        status: Execution status
+        passed: Whether all tests passed
+        output: Captured stdout/stderr
+        error: Error message if any
+        duration: Execution time in seconds
+    """
+
+    status: ExecutionStatus
+    passed: bool
+    output: str = ""
+    error: str | None = None
+    duration: float = 0.0
+
+
+class ExecutionAccuracy(Metric):
+    """Execute code and check against test cases.
+
+    This metric safely executes generated code in a restricted environment
+    and verifies correctness against provided test cases.
+
+    Security considerations:
+    - Executes in subprocess with timeout
+    - Restricted globals (no file I/O, network, etc.)
+    - Resource limits (memory, time)
+
+    Attributes:
+        name: Metric identifier ("execution_accuracy")
+        timeout: Maximum execution time per test (seconds)
+        max_memory_mb: Maximum memory usage (MB)
+
+    Example:
+        >>> from themis.evaluation.metrics.code import ExecutionAccuracy
+        >>> metric = ExecutionAccuracy(timeout=3.0)
+        >>>
+        >>> # Reference contains test cases
+        >>> test_cases = {
+        ...     "function_name": "add",
+        ...     "inputs": [(1, 2), (3, 4)],
+        ...     "expected": [3, 7]
+        ... }
+        >>>
+        >>> score = metric.compute(
+        ...     prediction="def add(a, b): return a + b",
+        ...     references=[test_cases]
+        ... )
+    """
+
+    requires_reference = True
+
+    def __init__(
+        self,
+        timeout: float = 3.0,
+        max_memory_mb: int = 512,
+    ):
+        """Initialize execution metric.
+
+        Args:
+            timeout: Maximum execution time per test (seconds)
+            max_memory_mb: Maximum memory usage (MB)
+        """
+        self.name = "execution_accuracy"
+        self.timeout = timeout
+        self.max_memory_mb = max_memory_mb
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> MetricScore:
+        """Execute code and compute accuracy.
+
+        Args:
+            prediction: Generated code to execute
+            references: List of test specifications
+            metadata: Optional metadata dict
+
+        Returns:
+            MetricScore with execution accuracy
+        """
+        code_str = str(prediction)
+
+        if not references:
+            return MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={"error": "No test cases provided"},
+                metadata=metadata or {},
+            )
+
+        # Extract test cases from reference
+        test_spec = references[0]
+        if not isinstance(test_spec, dict):
+            return MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={"error": "Test specification must be a dictionary"},
+                metadata=metadata or {},
+            )
+
+        test_inputs = test_spec.get("inputs", [])
+        expected_outputs = test_spec.get("expected", [])
+        test_fn_name = test_spec.get("function_name", "solution")
+
+        if len(test_inputs) != len(expected_outputs):
+            return MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={"error": "Mismatch between inputs and expected outputs"},
+                metadata=metadata or {},
+            )
+
+        # Execute code and run tests
+        results = []
+        for test_input, expected in zip(test_inputs, expected_outputs):
+            result = self._execute_test(
+                code_str,
+                test_fn_name,
+                test_input,
+                expected,
+            )
+            results.append(result)
+
+        # Compute accuracy
+        passed = sum(1 for r in results if r.passed)
+        total = len(results)
+        accuracy = passed / total if total > 0 else 0.0
+
+        return MetricScore(
+            metric_name=self.name,
+            value=accuracy,
+            details={
+                "accuracy": accuracy,
+                "passed": passed,
+                "total": total,
+                "results": [
+                    {
+                        "status": r.status.value,
+                        "passed": r.passed,
+                        "error": r.error,
+                        "duration": r.duration,
+                    }
+                    for r in results
+                ],
+            },
+            metadata=metadata or {},
+        )
+
+    def _execute_test(
+        self,
+        code: str,
+        function_name: str,
+        test_input: Any,
+        expected_output: Any,
+    ) -> ExecutionResult:
+        """Execute a single test case.
+
+        Args:
+            code: Code to execute
+            function_name: Name of function to test
+            test_input: Input to pass to function
+            expected_output: Expected output
+
+        Returns:
+            ExecutionResult with status and outcome
+        """
+        import time
+
+        start_time = time.time()
+
+        try:
+            # Create restricted globals (no file I/O, network, etc.)
+            restricted_globals = {
+                "__builtins__": {
+                    "abs": abs,
+                    "all": all,
+                    "any": any,
+                    "bool": bool,
+                    "dict": dict,
+                    "enumerate": enumerate,
+                    "filter": filter,
+                    "float": float,
+                    "int": int,
+                    "len": len,
+                    "list": list,
+                    "map": map,
+                    "max": max,
+                    "min": min,
+                    "range": range,
+                    "reversed": reversed,
+                    "set": set,
+                    "sorted": sorted,
+                    "str": str,
+                    "sum": sum,
+                    "tuple": tuple,
+                    "zip": zip,
+                }
+            }
+
+            # Execute code with timeout
+            local_vars = {}
+            exec(code, restricted_globals, local_vars)
+
+            # Get the function
+            if function_name not in local_vars:
+                return ExecutionResult(
+                    status=ExecutionStatus.ERROR,
+                    passed=False,
+                    error=f"Function '{function_name}' not found",
+                    duration=time.time() - start_time,
+                )
+
+            func = local_vars[function_name]
+
+            # Run function with input
+            if isinstance(test_input, (list, tuple)):
+                actual_output = func(*test_input)
+            else:
+                actual_output = func(test_input)
+
+            # Check if output matches expected
+            passed = actual_output == expected_output
+
+            return ExecutionResult(
+                status=ExecutionStatus.PASSED if passed else ExecutionStatus.FAILED,
+                passed=passed,
+                output=str(actual_output),
+                duration=time.time() - start_time,
+            )
+
+        except TimeoutError:
+            return ExecutionResult(
+                status=ExecutionStatus.TIMEOUT,
+                passed=False,
+                error=f"Execution timeout ({self.timeout}s)",
+                duration=self.timeout,
+            )
+        except Exception as e:
+            return ExecutionResult(
+                status=ExecutionStatus.ERROR,
+                passed=False,
+                error=str(e),
+                duration=time.time() - start_time,
+            )
+
+
+__all__ = ["ExecutionAccuracy", "ExecutionResult", "ExecutionStatus"]
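Usage sketch (not part of the diff): scoring one generated solution with `ExecutionAccuracy`. The `function_name`/`inputs`/`expected` keys mirror what `compute()` above reads from the test specification; the concrete task and values are invented for illustration.

```python
from themis.evaluation.metrics.code import ExecutionAccuracy

# Test specification keyed the way compute() reads it above;
# the task itself ("add") is a made-up example.
test_spec = {
    "function_name": "add",
    "inputs": [(1, 2), (3, 4)],   # positional arguments per test case
    "expected": [3, 7],           # expected return values
}

metric = ExecutionAccuracy(timeout=3.0)
score = metric.compute(
    prediction="def add(a, b): return a + b",
    references=[test_spec],
)

print(score.value)               # 1.0 if both test cases pass
print(score.details["passed"])   # number of passing test cases
```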
themis/evaluation/metrics/code/pass_at_k.py

@@ -0,0 +1,181 @@
+"""Pass@k metric for code generation evaluation.
+
+Pass@k measures functional correctness by executing k generated code samples
+and checking if any of them pass the test cases.
+
+References:
+    Chen et al. (2021). Evaluating Large Language Models Trained on Code.
+    (HumanEval paper)
+"""
+
+from __future__ import annotations
+
+import math
+from typing import Any, Sequence
+
+from themis.core.entities import MetricScore
+from themis.interfaces import Metric
+
+
+def estimate_pass_at_k(n: int, c: int, k: int) -> float:
+    """Estimate pass@k using unbiased estimator.
+
+    This is the standard estimator from the HumanEval paper.
+
+    Args:
+        n: Total number of samples generated
+        c: Number of samples that passed
+        k: k value for pass@k
+
+    Returns:
+        Estimated pass@k probability
+
+    Example:
+        >>> # Generated 10 samples, 3 passed, compute pass@1
+        >>> estimate_pass_at_k(n=10, c=3, k=1)
+        0.3
+
+        >>> # Generated 100 samples, 30 passed, compute pass@10
+        >>> estimate_pass_at_k(n=100, c=30, k=10)
+        0.9771
+    """
+    if n - c < k:
+        return 1.0
+
+    # Unbiased estimator: 1 - C(n-c, k) / C(n, k)
+    # = 1 - product((n-c-i)/(n-i) for i in range(k))
+    result = 1.0
+    for i in range(k):
+        result *= (n - c - i) / (n - i)
+
+    return 1.0 - result
+
+
+class PassAtK(Metric):
+    """Pass@k metric for code generation.
+
+    Pass@k measures the probability that at least one of k generated samples
+    passes all test cases. It's the standard metric for evaluating code
+    generation models like Codex, CodeGen, etc.
+
+    The metric requires:
+    - Multiple samples per problem (num_samples >= k)
+    - Test cases to execute against
+    - Safe code execution environment
+
+    Attributes:
+        name: Metric identifier ("pass_at_k")
+        k: Number of samples to consider
+        timeout: Maximum execution time per sample (seconds)
+        require_all_tests: Whether all tests must pass (vs any test)
+
+    Example:
+        >>> from themis.evaluation.metrics.code import PassAtK
+        >>> metric = PassAtK(k=1)
+        >>> score = metric.compute(
+        ...     prediction={
+        ...         "samples": ["def add(a, b): return a + b", ...],
+        ...         "test_results": [True, False, ...],
+        ...     },
+        ...     references=[]
+        ... )
+        >>> print(f"Pass@1: {score.value:.2%}")
+        Pass@1: 30.00%
+    """
+
+    requires_reference = False  # Uses test execution, not reference matching
+
+    def __init__(
+        self,
+        k: int = 1,
+        timeout: float = 3.0,
+        require_all_tests: bool = True,
+    ):
+        """Initialize Pass@k metric.
+
+        Args:
+            k: Number of samples for pass@k estimation
+            timeout: Maximum execution time per sample (seconds)
+            require_all_tests: Whether all test cases must pass (default: True)
+        """
+        self.name = f"pass_at_{k}"
+        self.k = k
+        self.timeout = timeout
+        self.require_all_tests = require_all_tests
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> MetricScore:
+        """Compute Pass@k score.
+
+        Args:
+            prediction: Dictionary containing:
+                - "samples": List of generated code samples
+                - "test_results": List of booleans (True if passed)
+                - "execution_errors": Optional list of error messages
+            references: Not used (test-based evaluation)
+            metadata: Optional metadata dict
+
+        Returns:
+            MetricScore with estimated pass@k probability
+
+        Note:
+            The prediction should be prepared by ExecutionAccuracy metric
+            or similar execution framework.
+        """
+        if not isinstance(prediction, dict):
+            return MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={"error": "Prediction must be dict with samples and test_results"},
+                metadata=metadata or {},
+            )
+
+        samples = prediction.get("samples", [])
+        test_results = prediction.get("test_results", [])
+
+        if not samples or not test_results:
+            return MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={
+                    "error": "Missing samples or test_results",
+                    "num_samples": len(samples),
+                    "num_results": len(test_results),
+                },
+                metadata=metadata or {},
+            )
+
+        # Count number of samples and passes
+        n = len(test_results)
+        c = sum(1 for result in test_results if result)
+
+        # Estimate pass@k
+        if n < self.k:
+            # Not enough samples, use empirical rate
+            pass_at_k = c / n if n > 0 else 0.0
+            warning = f"Only {n} samples available for pass@{self.k}"
+        else:
+            pass_at_k = estimate_pass_at_k(n, c, self.k)
+            warning = None
+
+        return MetricScore(
+            metric_name=self.name,
+            value=pass_at_k,
+            details={
+                "k": self.k,
+                "n_samples": n,
+                "n_passed": c,
+                "pass_rate": c / n if n > 0 else 0.0,
+                "pass_at_k": pass_at_k,
+                "warning": warning,
+            },
+            metadata=metadata or {},
+        )
+
+
+__all__ = ["PassAtK", "estimate_pass_at_k"]
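To make the estimator concrete, here is a short sketch (not part of the diff) that recomputes the docstring figures with `estimate_pass_at_k` and cross-checks them against the closed form 1 − C(n−c, k)/C(n, k).

```python
from math import comb

from themis.evaluation.metrics.code.pass_at_k import estimate_pass_at_k

# 10 samples, 3 passing: pass@1 is just the empirical pass rate c/n.
print(round(estimate_pass_at_k(n=10, c=3, k=1), 4))      # 0.3

# 100 samples, 30 passing: pass@10 is much higher than the raw pass rate.
print(round(estimate_pass_at_k(n=100, c=30, k=10), 4))   # 0.9771

# Cross-check against the binomial form referenced in the comment above.
print(round(1 - comb(100 - 30, 10) / comb(100, 10), 4))  # 0.9771
```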
themis/evaluation/metrics/nlp/__init__.py

@@ -0,0 +1,21 @@
+"""NLP evaluation metrics.
+
+This module provides standard NLP metrics for text generation evaluation:
+- BLEU: Bilingual Evaluation Understudy for translation quality
+- ROUGE: Recall-Oriented Understudy for Gisting Evaluation for summarization
+- BERTScore: Contextual embeddings-based evaluation
+- METEOR: Metric for Evaluation of Translation with Explicit ORdering
+"""
+
+from themis.evaluation.metrics.nlp.bleu import BLEU
+from themis.evaluation.metrics.nlp.rouge import ROUGE, ROUGEVariant
+from themis.evaluation.metrics.nlp.bertscore import BERTScore
+from themis.evaluation.metrics.nlp.meteor import METEOR
+
+__all__ = [
+    "BLEU",
+    "ROUGE",
+    "ROUGEVariant",
+    "BERTScore",
+    "METEOR",
+]
themis/evaluation/metrics/nlp/bertscore.py

@@ -0,0 +1,138 @@
+"""BERTScore metric implementation.
+
+BERTScore computes similarity using contextual embeddings from BERT-like models
+instead of exact word matches.
+
+References:
+    Zhang et al. (2020). BERTScore: Evaluating Text Generation with BERT.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Sequence
+
+from themis.core.entities import MetricScore
+from themis.interfaces import Metric
+
+
+class BERTScore(Metric):
+    """BERTScore metric using bert-score library.
+
+    BERTScore leverages contextual embeddings from pre-trained models (BERT, RoBERTa, etc.)
+    to compute semantic similarity between generated and reference texts. It's more
+    robust to paraphrasing than exact n-gram matching methods.
+
+    The metric computes token-level cosine similarity between embeddings and aggregates
+    using precision, recall, and F1.
+
+    Attributes:
+        name: Metric identifier ("bertscore")
+        model_type: Pre-trained model to use for embeddings
+        lang: Language code for automatic model selection
+        rescale_with_baseline: Whether to rescale scores using baseline
+
+    Example:
+        >>> from themis.evaluation.metrics.nlp import BERTScore
+        >>> metric = BERTScore(model_type="microsoft/deberta-xlarge-mnli")
+        >>> score = metric.compute(
+        ...     prediction="The cat sat on the mat",
+        ...     references=["A cat is sitting on a mat"]
+        ... )
+        >>> print(f"BERTScore F1: {score.value:.4f}")
+        BERTScore F1: 0.9234
+    """
+
+    requires_reference = True
+
+    def __init__(
+        self,
+        model_type: str | None = None,
+        lang: str | None = None,
+        rescale_with_baseline: bool = True,
+        device: str | None = None,
+    ):
+        """Initialize BERTScore metric.
+
+        Args:
+            model_type: Pre-trained model identifier. Popular choices:
+                - "microsoft/deberta-xlarge-mnli" (recommended, large)
+                - "microsoft/deberta-large-mnli" (good balance)
+                - "roberta-large" (fast, good quality)
+                - "bert-base-uncased" (fastest, lower quality)
+            lang: Language code (e.g., "en", "zh", "fr"). If provided,
+                automatically selects appropriate model.
+            rescale_with_baseline: Whether to rescale scores using baseline
+                (recommended for human correlation)
+            device: Device to use ("cuda", "cpu", or None for auto-detect)
+        """
+        self.name = "bertscore"
+        self.model_type = model_type
+        self.lang = lang
+        self.rescale_with_baseline = rescale_with_baseline
+        self.device = device
+
+        # Lazy import bert-score (not required for all users)
+        try:
+            import bert_score
+            self._bert_score = bert_score
+        except ImportError:
+            raise ImportError(
+                "bert-score is required for BERTScore metric. "
+                "Install it with: pip install bert-score"
+            )
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> MetricScore:
+        """Compute BERTScore.
+
+        Args:
+            prediction: Generated text (already extracted by pipeline)
+            references: List of reference texts
+            metadata: Optional metadata dict
+
+        Returns:
+            MetricScore with BERTScore F1 and precision/recall details
+        """
+        # Convert to strings
+        pred_str = str(prediction)
+        ref_strs = [str(ref) for ref in references]
+
+        # Compute BERTScore
+        # Note: bert_score.score expects lists of predictions and references
+        P, R, F1 = self._bert_score.score(
+            [pred_str] * len(ref_strs),  # Repeat prediction for each reference
+            ref_strs,
+            model_type=self.model_type,
+            lang=self.lang,
+            rescale_with_baseline=self.rescale_with_baseline,
+            device=self.device,
+            verbose=False,
+        )
+
+        # Take maximum F1 across references
+        max_idx = F1.argmax().item()
+        max_precision = P[max_idx].item()
+        max_recall = R[max_idx].item()
+        max_f1 = F1[max_idx].item()
+
+        return MetricScore(
+            metric_name=self.name,
+            value=max_f1,  # Use F1 as primary score
+            details={
+                "precision": max_precision,
+                "recall": max_recall,
+                "f1": max_f1,
+                "model_type": self.model_type or f"auto-{self.lang}",
+                "num_references": len(ref_strs),
+                "rescaled": self.rescale_with_baseline,
+            },
+            metadata=metadata or {},
+        )
+
+
+__all__ = ["BERTScore"]
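Usage sketch for the new BERTScore metric (not part of the diff). It assumes the optional `bert-score` dependency is installed; the first call downloads the underlying model, and the resulting scores depend on that model, so any printed numbers are indicative only.

```python
# Requires the optional dependency: pip install bert-score
from themis.evaluation.metrics.nlp import BERTScore

# lang="en" lets bert-score pick a default English model;
# rescale_with_baseline=True matches the constructor default above.
metric = BERTScore(lang="en", rescale_with_baseline=True)

score = metric.compute(
    prediction="The cat sat on the mat",
    references=[
        "A cat is sitting on a mat",
        "There is a cat on the mat",
    ],
)

# value holds the best F1 across references; precision/recall live in details.
print(round(score.value, 4))
print(score.details["precision"], score.details["recall"])
```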