themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. themis/__init__.py +12 -1
  2. themis/_version.py +2 -2
  3. themis/api.py +343 -0
  4. themis/backends/__init__.py +17 -0
  5. themis/backends/execution.py +197 -0
  6. themis/backends/storage.py +260 -0
  7. themis/cli/__init__.py +5 -0
  8. themis/cli/__main__.py +6 -0
  9. themis/cli/commands/__init__.py +19 -0
  10. themis/cli/commands/benchmarks.py +221 -0
  11. themis/cli/commands/comparison.py +394 -0
  12. themis/cli/commands/config_commands.py +244 -0
  13. themis/cli/commands/cost.py +214 -0
  14. themis/cli/commands/demo.py +68 -0
  15. themis/cli/commands/info.py +90 -0
  16. themis/cli/commands/leaderboard.py +362 -0
  17. themis/cli/commands/math_benchmarks.py +318 -0
  18. themis/cli/commands/mcq_benchmarks.py +207 -0
  19. themis/cli/commands/results.py +252 -0
  20. themis/cli/commands/sample_run.py +244 -0
  21. themis/cli/commands/visualize.py +299 -0
  22. themis/cli/main.py +463 -0
  23. themis/cli/new_project.py +33 -0
  24. themis/cli/utils.py +51 -0
  25. themis/comparison/__init__.py +25 -0
  26. themis/comparison/engine.py +348 -0
  27. themis/comparison/reports.py +283 -0
  28. themis/comparison/statistics.py +402 -0
  29. themis/config/__init__.py +19 -0
  30. themis/config/loader.py +27 -0
  31. themis/config/registry.py +34 -0
  32. themis/config/runtime.py +214 -0
  33. themis/config/schema.py +112 -0
  34. themis/core/__init__.py +5 -0
  35. themis/core/conversation.py +354 -0
  36. themis/core/entities.py +184 -0
  37. themis/core/serialization.py +231 -0
  38. themis/core/tools.py +393 -0
  39. themis/core/types.py +141 -0
  40. themis/datasets/__init__.py +273 -0
  41. themis/datasets/base.py +264 -0
  42. themis/datasets/commonsense_qa.py +174 -0
  43. themis/datasets/competition_math.py +265 -0
  44. themis/datasets/coqa.py +133 -0
  45. themis/datasets/gpqa.py +190 -0
  46. themis/datasets/gsm8k.py +123 -0
  47. themis/datasets/gsm_symbolic.py +124 -0
  48. themis/datasets/math500.py +122 -0
  49. themis/datasets/med_qa.py +179 -0
  50. themis/datasets/medmcqa.py +169 -0
  51. themis/datasets/mmlu_pro.py +262 -0
  52. themis/datasets/piqa.py +146 -0
  53. themis/datasets/registry.py +201 -0
  54. themis/datasets/schema.py +245 -0
  55. themis/datasets/sciq.py +150 -0
  56. themis/datasets/social_i_qa.py +151 -0
  57. themis/datasets/super_gpqa.py +263 -0
  58. themis/evaluation/__init__.py +1 -0
  59. themis/evaluation/conditional.py +410 -0
  60. themis/evaluation/extractors/__init__.py +19 -0
  61. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  62. themis/evaluation/extractors/exceptions.py +7 -0
  63. themis/evaluation/extractors/identity_extractor.py +29 -0
  64. themis/evaluation/extractors/json_field_extractor.py +45 -0
  65. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  66. themis/evaluation/extractors/regex_extractor.py +43 -0
  67. themis/evaluation/math_verify_utils.py +87 -0
  68. themis/evaluation/metrics/__init__.py +21 -0
  69. themis/evaluation/metrics/code/__init__.py +19 -0
  70. themis/evaluation/metrics/code/codebleu.py +144 -0
  71. themis/evaluation/metrics/code/execution.py +280 -0
  72. themis/evaluation/metrics/code/pass_at_k.py +181 -0
  73. themis/evaluation/metrics/composite_metric.py +47 -0
  74. themis/evaluation/metrics/consistency_metric.py +80 -0
  75. themis/evaluation/metrics/exact_match.py +51 -0
  76. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  77. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  78. themis/evaluation/metrics/nlp/__init__.py +21 -0
  79. themis/evaluation/metrics/nlp/bertscore.py +138 -0
  80. themis/evaluation/metrics/nlp/bleu.py +129 -0
  81. themis/evaluation/metrics/nlp/meteor.py +153 -0
  82. themis/evaluation/metrics/nlp/rouge.py +136 -0
  83. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  84. themis/evaluation/metrics/response_length.py +33 -0
  85. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  86. themis/evaluation/pipeline.py +49 -0
  87. themis/evaluation/pipelines/__init__.py +15 -0
  88. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  89. themis/evaluation/pipelines/standard_pipeline.py +348 -0
  90. themis/evaluation/reports.py +293 -0
  91. themis/evaluation/statistics/__init__.py +53 -0
  92. themis/evaluation/statistics/bootstrap.py +79 -0
  93. themis/evaluation/statistics/confidence_intervals.py +121 -0
  94. themis/evaluation/statistics/distributions.py +207 -0
  95. themis/evaluation/statistics/effect_sizes.py +124 -0
  96. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  97. themis/evaluation/statistics/types.py +139 -0
  98. themis/evaluation/strategies/__init__.py +13 -0
  99. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  100. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  101. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  102. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  103. themis/experiment/__init__.py +5 -0
  104. themis/experiment/builder.py +151 -0
  105. themis/experiment/cache_manager.py +134 -0
  106. themis/experiment/comparison.py +631 -0
  107. themis/experiment/cost.py +310 -0
  108. themis/experiment/definitions.py +62 -0
  109. themis/experiment/export.py +798 -0
  110. themis/experiment/export_csv.py +159 -0
  111. themis/experiment/integration_manager.py +104 -0
  112. themis/experiment/math.py +192 -0
  113. themis/experiment/mcq.py +169 -0
  114. themis/experiment/orchestrator.py +415 -0
  115. themis/experiment/pricing.py +317 -0
  116. themis/experiment/storage.py +1458 -0
  117. themis/experiment/visualization.py +588 -0
  118. themis/generation/__init__.py +1 -0
  119. themis/generation/agentic_runner.py +420 -0
  120. themis/generation/batching.py +254 -0
  121. themis/generation/clients.py +143 -0
  122. themis/generation/conversation_runner.py +236 -0
  123. themis/generation/plan.py +456 -0
  124. themis/generation/providers/litellm_provider.py +221 -0
  125. themis/generation/providers/vllm_provider.py +135 -0
  126. themis/generation/router.py +34 -0
  127. themis/generation/runner.py +207 -0
  128. themis/generation/strategies.py +98 -0
  129. themis/generation/templates.py +71 -0
  130. themis/generation/turn_strategies.py +393 -0
  131. themis/generation/types.py +9 -0
  132. themis/integrations/__init__.py +0 -0
  133. themis/integrations/huggingface.py +72 -0
  134. themis/integrations/wandb.py +77 -0
  135. themis/interfaces/__init__.py +169 -0
  136. themis/presets/__init__.py +10 -0
  137. themis/presets/benchmarks.py +354 -0
  138. themis/presets/models.py +190 -0
  139. themis/project/__init__.py +20 -0
  140. themis/project/definitions.py +98 -0
  141. themis/project/patterns.py +230 -0
  142. themis/providers/__init__.py +5 -0
  143. themis/providers/registry.py +39 -0
  144. themis/server/__init__.py +28 -0
  145. themis/server/app.py +337 -0
  146. themis/utils/api_generator.py +379 -0
  147. themis/utils/cost_tracking.py +376 -0
  148. themis/utils/dashboard.py +452 -0
  149. themis/utils/logging_utils.py +41 -0
  150. themis/utils/progress.py +58 -0
  151. themis/utils/tracing.py +320 -0
  152. themis_eval-0.2.0.dist-info/METADATA +596 -0
  153. themis_eval-0.2.0.dist-info/RECORD +157 -0
  154. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
  155. themis_eval-0.1.0.dist-info/METADATA +0 -758
  156. themis_eval-0.1.0.dist-info/RECORD +0 -8
  157. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
  158. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/evaluation/metrics/code/pass_at_k.py (new file)
@@ -0,0 +1,181 @@
+"""Pass@k metric for code generation evaluation.
+
+Pass@k measures functional correctness by executing k generated code samples
+and checking if any of them pass the test cases.
+
+References:
+    Chen et al. (2021). Evaluating Large Language Models Trained on Code.
+    (HumanEval paper)
+"""
+
+from __future__ import annotations
+
+import math
+from typing import Any, Sequence
+
+from themis.core.entities import MetricScore
+from themis.interfaces import Metric
+
+
+def estimate_pass_at_k(n: int, c: int, k: int) -> float:
+    """Estimate pass@k using unbiased estimator.
+
+    This is the standard estimator from the HumanEval paper.
+
+    Args:
+        n: Total number of samples generated
+        c: Number of samples that passed
+        k: k value for pass@k
+
+    Returns:
+        Estimated pass@k probability
+
+    Example:
+        >>> # Generated 10 samples, 3 passed, compute pass@1
+        >>> estimate_pass_at_k(n=10, c=3, k=1)
+        0.3
+
+        >>> # Generated 100 samples, 30 passed, compute pass@10
+        >>> estimate_pass_at_k(n=100, c=30, k=10)
+        0.8926
+    """
+    if n - c < k:
+        return 1.0
+
+    # Unbiased estimator: 1 - C(n-c, k) / C(n, k)
+    # = 1 - product((n-c-i)/(n-i) for i in range(k))
+    result = 1.0
+    for i in range(k):
+        result *= (n - c - i) / (n - i)
+
+    return 1.0 - result
+
+
+class PassAtK(Metric):
+    """Pass@k metric for code generation.
+
+    Pass@k measures the probability that at least one of k generated samples
+    passes all test cases. It's the standard metric for evaluating code
+    generation models like Codex, CodeGen, etc.
+
+    The metric requires:
+    - Multiple samples per problem (num_samples >= k)
+    - Test cases to execute against
+    - Safe code execution environment
+
+    Attributes:
+        name: Metric identifier ("pass_at_k")
+        k: Number of samples to consider
+        timeout: Maximum execution time per sample (seconds)
+        require_all_tests: Whether all tests must pass (vs any test)
+
+    Example:
+        >>> from themis.evaluation.metrics.code import PassAtK
+        >>> metric = PassAtK(k=1)
+        >>> score = metric.compute(
+        ...     prediction={
+        ...         "samples": ["def add(a, b): return a + b", ...],
+        ...         "test_results": [True, False, ...],
+        ...     },
+        ...     references=[]
+        ... )
+        >>> print(f"Pass@1: {score.value:.2%}")
+        Pass@1: 30.00%
+    """
+
+    requires_reference = False  # Uses test execution, not reference matching
+
+    def __init__(
+        self,
+        k: int = 1,
+        timeout: float = 3.0,
+        require_all_tests: bool = True,
+    ):
+        """Initialize Pass@k metric.
+
+        Args:
+            k: Number of samples for pass@k estimation
+            timeout: Maximum execution time per sample (seconds)
+            require_all_tests: Whether all test cases must pass (default: True)
+        """
+        self.name = f"pass_at_{k}"
+        self.k = k
+        self.timeout = timeout
+        self.require_all_tests = require_all_tests
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> MetricScore:
+        """Compute Pass@k score.
+
+        Args:
+            prediction: Dictionary containing:
+                - "samples": List of generated code samples
+                - "test_results": List of booleans (True if passed)
+                - "execution_errors": Optional list of error messages
+            references: Not used (test-based evaluation)
+            metadata: Optional metadata dict
+
+        Returns:
+            MetricScore with estimated pass@k probability
+
+        Note:
+            The prediction should be prepared by ExecutionAccuracy metric
+            or similar execution framework.
+        """
+        if not isinstance(prediction, dict):
+            return MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={"error": "Prediction must be dict with samples and test_results"},
+                metadata=metadata or {},
+            )
+
+        samples = prediction.get("samples", [])
+        test_results = prediction.get("test_results", [])
+
+        if not samples or not test_results:
+            return MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={
+                    "error": "Missing samples or test_results",
+                    "num_samples": len(samples),
+                    "num_results": len(test_results),
+                },
+                metadata=metadata or {},
+            )
+
+        # Count number of samples and passes
+        n = len(test_results)
+        c = sum(1 for result in test_results if result)
+
+        # Estimate pass@k
+        if n < self.k:
+            # Not enough samples, use empirical rate
+            pass_at_k = c / n if n > 0 else 0.0
+            warning = f"Only {n} samples available for pass@{self.k}"
+        else:
+            pass_at_k = estimate_pass_at_k(n, c, self.k)
+            warning = None
+
+        return MetricScore(
+            metric_name=self.name,
+            value=pass_at_k,
+            details={
+                "k": self.k,
+                "n_samples": n,
+                "n_passed": c,
+                "pass_rate": c / n if n > 0 else 0.0,
+                "pass_at_k": pass_at_k,
+                "warning": warning,
+            },
+            metadata=metadata or {},
+        )
+
+
+__all__ = ["PassAtK", "estimate_pass_at_k"]
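The estimator above implements the HumanEval formula pass@k = 1 - C(n-c, k) / C(n, k) as a running product to avoid large binomial coefficients. As a standalone sanity check (not part of the package), the closed form can be written with the standard library and should agree with estimate_pass_at_k up to floating-point error:

    from math import comb

    def pass_at_k_closed_form(n: int, c: int, k: int) -> float:
        # 1 - C(n - c, k) / C(n, k); by convention 1.0 when fewer than k samples failed
        if n - c < k:
            return 1.0
        return 1.0 - comb(n - c, k) / comb(n, k)

    # Example: 10 samples, 3 passing, pass@1 = 1 - 7/10 = 0.3
    assert abs(pass_at_k_closed_form(n=10, c=3, k=1) - 0.3) < 1e-9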
themis/evaluation/metrics/composite_metric.py (new file)
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+from themis.core import entities as core_entities
+from themis.interfaces import Metric as MetricInterface
+
+
+@dataclass
+class CompositeMetric(MetricInterface):
+    children: Sequence[MetricInterface]
+
+    def __post_init__(self) -> None:
+        self.name = "CompositeMetric"
+        self.requires_reference = any(
+            getattr(child, "requires_reference", True) for child in self.children
+        )
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        child_results = [
+            child.compute(
+                prediction=prediction, references=references, metadata=metadata
+            )
+            for child in self.children
+        ]
+        if not child_results:
+            return core_entities.MetricScore(
+                metric_name=self.name,
+                value=0.0,
+                details={},
+                metadata=dict(metadata or {}),
+            )
+        value = sum(result.value for result in child_results) / len(child_results)
+        details = {result.metric_name: result.details for result in child_results}
+        return core_entities.MetricScore(
+            metric_name=self.name,
+            value=value,
+            details=details,
+            metadata=dict(metadata or {}),
+        )
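CompositeMetric reports the unweighted mean of its children's values and nests each child's details under that child's metric_name (so children with duplicate names would overwrite each other in the details dict). A minimal usage sketch, assuming the classes can be imported directly from the module paths listed in this release (the public re-export surface may differ):

    from themis.evaluation.metrics.composite_metric import CompositeMetric
    from themis.evaluation.metrics.exact_match import ExactMatch
    from themis.evaluation.metrics.length_difference_tolerance import LengthDifferenceTolerance

    metric = CompositeMetric(children=[ExactMatch(), LengthDifferenceTolerance(max_delta=2)])
    score = metric.compute(prediction="Paris", references=["paris"])

    print(score.value)           # 1.0 -> mean of ExactMatch (1.0) and length check (delta 0 <= 2 -> 1.0)
    print(list(score.details))   # ['ExactMatch', 'LengthDifferenceTolerance']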
themis/evaluation/metrics/consistency_metric.py (new file)
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+from themis.core import entities as core_entities
+from themis.interfaces import Metric as MetricInterface
+
+
+def _normalize_text(value: str, case_sensitive: bool, strip_whitespace: bool) -> str:
+    if strip_whitespace:
+        value = value.strip()
+    if not case_sensitive:
+        value = value.lower()
+    return value
+
+
+@dataclass
+class ConsistencyMetric(MetricInterface):
+    case_sensitive: bool = False
+    strip_whitespace: bool = True
+
+    def __post_init__(self) -> None:
+        self.name = "Consistency"
+        self.requires_reference = False
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        md = dict(metadata or {})
+
+        outputs: list[str]
+        if isinstance(prediction, (list, tuple)):
+            outputs = [str(p) for p in prediction]
+        else:
+            outputs = [str(prediction)]
+
+        normalized = [
+            _normalize_text(text, self.case_sensitive, self.strip_whitespace)
+            for text in outputs
+        ]
+
+        majority_correct = None
+        reference_text = None
+        if references:
+            reference_text = _normalize_text(
+                str(references[0]), self.case_sensitive, self.strip_whitespace
+            )
+            correct = [1.0 if out == reference_text else 0.0 for out in normalized]
+            majority_correct = sum(correct) / max(1, len(correct))
+
+        from collections import Counter
+
+        counter = Counter(normalized)
+        mode_count = max(counter.values()) if counter else 0
+        agreement = mode_count / max(1, len(normalized))
+
+        flips = 0
+        for i in range(1, len(normalized)):
+            if normalized[i] != normalized[i - 1]:
+                flips += 1
+        flip_rate = flips / max(1, len(normalized) - 1)
+
+        value = majority_correct if majority_correct is not None else agreement
+
+        return core_entities.MetricScore(
+            metric_name=self.name,
+            value=float(value),
+            details={
+                "agreement": agreement,
+                "flip_rate": flip_rate,
+                "outputs": outputs,
+                "reference": reference_text,
+            },
+            metadata=md,
+        )
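ConsistencyMetric is aimed at self-consistency style evaluation, where prediction is a list of sampled outputs for the same item: agreement is the fraction of outputs matching the most common answer, flip_rate is the fraction of consecutive pairs that differ, and the reported value is majority accuracy against the first reference when one is supplied (otherwise agreement). A brief sketch under the same import assumption as above:

    from themis.evaluation.metrics.consistency_metric import ConsistencyMetric

    metric = ConsistencyMetric()
    score = metric.compute(
        prediction=["42", "42 ", "41", "42"],   # four samples for one prompt
        references=["42"],
    )

    print(score.value)                     # 0.75 -> 3 of 4 samples match the reference
    print(score.details["agreement"])      # 0.75 -> modal answer "42" covers 3 of 4 outputs
    print(score.details["flip_rate"])      # 0.666... -> 2 changes across 3 consecutive pairs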
themis/evaluation/metrics/exact_match.py (new file)
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+from themis.core import entities as core_entities
+from themis.interfaces import Metric as MetricInterface
+
+
+def _normalize_text(value: str, case_sensitive: bool, strip_whitespace: bool) -> str:
+    if strip_whitespace:
+        value = value.strip()
+    if not case_sensitive:
+        value = value.lower()
+    return value
+
+
+@dataclass
+class ExactMatch(MetricInterface):
+    case_sensitive: bool = False
+    strip_whitespace: bool = True
+
+    def __post_init__(self) -> None:
+        self.name = "ExactMatch"
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        metadata = dict(metadata or {})
+        normalized_prediction = _normalize_text(
+            str(prediction), self.case_sensitive, self.strip_whitespace
+        )
+        matched_reference: str | None = None
+        for reference in references:
+            normalized_reference = _normalize_text(
+                str(reference), self.case_sensitive, self.strip_whitespace
+            )
+            if normalized_prediction == normalized_reference:
+                matched_reference = str(reference)
+                break
+        value = 1.0 if matched_reference is not None else 0.0
+        return core_entities.MetricScore(
+            metric_name=self.name,
+            value=value,
+            details={"matched_reference": matched_reference},
+            metadata=metadata,
+        )
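ExactMatch normalizes both sides (whitespace stripped and lowercased by default) and returns 1.0 when the prediction equals any reference, recording the matching reference in its original form. For example, under the same import assumption:

    from themis.evaluation.metrics.exact_match import ExactMatch

    metric = ExactMatch()   # case-insensitive, whitespace-stripped by default
    score = metric.compute(prediction="  YES ", references=["no", "yes"])

    print(score.value)                          # 1.0
    print(score.details["matched_reference"])   # 'yes'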
themis/evaluation/metrics/length_difference_tolerance.py (new file)
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+from themis.core import entities as core_entities
+from themis.interfaces import Metric as MetricInterface
+
+
+@dataclass
+class LengthDifferenceTolerance(MetricInterface):
+    max_delta: int = 0
+
+    def __post_init__(self) -> None:
+        self.name = "LengthDifferenceTolerance"
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        metadata = dict(metadata or {})
+        reference = str(references[0]) if references else ""
+        delta = abs(len(str(prediction)) - len(reference))
+        value = 1.0 if delta <= self.max_delta else 0.0
+        return core_entities.MetricScore(
+            metric_name=self.name,
+            value=value,
+            details={"delta": delta},
+            metadata=metadata,
+        )
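LengthDifferenceTolerance is a simple length check against the first reference: it scores 1.0 when the absolute character-count difference is at most max_delta (0 by default). A one-line sketch, same import caveat:

    from themis.evaluation.metrics.length_difference_tolerance import LengthDifferenceTolerance

    metric = LengthDifferenceTolerance(max_delta=3)
    score = metric.compute(prediction="abcdef", references=["abc"])
    print(score.value, score.details["delta"])   # 1.0 3 -> |6 - 3| = 3 <= max_delta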
themis/evaluation/metrics/math_verify_accuracy.py (new file)
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+from themis.core import entities as core_entities
+from themis.evaluation import math_verify_utils
+from themis.interfaces import Metric as MetricInterface
+
+
+@dataclass
+class MathVerifyAccuracy(MetricInterface):
+    """Numeric equivalence check using math-verify."""
+
+    def __post_init__(self) -> None:
+        math_verify_utils.require_math_verify()
+        self.name = "MathVerifyAccuracy"
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> core_entities.MetricScore:
+        math_verify_utils.require_math_verify()
+        metadata = dict(metadata or {})
+        prediction_expr = math_verify_utils.parse_expression(str(prediction))
+        passed = False
+        for reference in references:
+            reference_expr = math_verify_utils.parse_expression(str(reference))
+            if math_verify_utils.verify_expressions(reference_expr, prediction_expr):
+                passed = True
+                break
+        return core_entities.MetricScore(
+            metric_name=self.name,
+            value=1.0 if passed else 0.0,
+            details={"verified": passed},
+            metadata=metadata,
+        )
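MathVerifyAccuracy delegates parsing and equivalence checking to the optional math-verify dependency via themis.evaluation.math_verify_utils, so instantiation fails unless that extra is installed. A hedged sketch of the intended use (assuming math-verify is installed and judges the two expressions equivalent):

    from themis.evaluation.metrics.math_verify_accuracy import MathVerifyAccuracy

    metric = MathVerifyAccuracy()   # raises if the math-verify package is missing
    score = metric.compute(prediction="1/2", references=["0.5"])
    print(score.value, score.details["verified"])   # expected 1.0 True if math-verify verifies 1/2 == 0.5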
themis/evaluation/metrics/nlp/__init__.py (new file)
@@ -0,0 +1,21 @@
+"""NLP evaluation metrics.
+
+This module provides standard NLP metrics for text generation evaluation:
+- BLEU: Bilingual Evaluation Understudy for translation quality
+- ROUGE: Recall-Oriented Understudy for Gisting Evaluation for summarization
+- BERTScore: Contextual embeddings-based evaluation
+- METEOR: Metric for Evaluation of Translation with Explicit ORdering
+"""
+
+from themis.evaluation.metrics.nlp.bleu import BLEU
+from themis.evaluation.metrics.nlp.rouge import ROUGE, ROUGEVariant
+from themis.evaluation.metrics.nlp.bertscore import BERTScore
+from themis.evaluation.metrics.nlp.meteor import METEOR
+
+__all__ = [
+    "BLEU",
+    "ROUGE",
+    "ROUGEVariant",
+    "BERTScore",
+    "METEOR",
+]
themis/evaluation/metrics/nlp/bertscore.py (new file)
@@ -0,0 +1,138 @@
+"""BERTScore metric implementation.
+
+BERTScore computes similarity using contextual embeddings from BERT-like models
+instead of exact word matches.
+
+References:
+    Zhang et al. (2020). BERTScore: Evaluating Text Generation with BERT.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Sequence
+
+from themis.core.entities import MetricScore
+from themis.interfaces import Metric
+
+
+class BERTScore(Metric):
+    """BERTScore metric using bert-score library.
+
+    BERTScore leverages contextual embeddings from pre-trained models (BERT, RoBERTa, etc.)
+    to compute semantic similarity between generated and reference texts. It's more
+    robust to paraphrasing than exact n-gram matching methods.
+
+    The metric computes token-level cosine similarity between embeddings and aggregates
+    using precision, recall, and F1.
+
+    Attributes:
+        name: Metric identifier ("bertscore")
+        model_type: Pre-trained model to use for embeddings
+        lang: Language code for automatic model selection
+        rescale_with_baseline: Whether to rescale scores using baseline
+
+    Example:
+        >>> from themis.evaluation.metrics.nlp import BERTScore
+        >>> metric = BERTScore(model_type="microsoft/deberta-xlarge-mnli")
+        >>> score = metric.compute(
+        ...     prediction="The cat sat on the mat",
+        ...     references=["A cat is sitting on a mat"]
+        ... )
+        >>> print(f"BERTScore F1: {score.value:.4f}")
+        BERTScore F1: 0.9234
+    """
+
+    requires_reference = True
+
+    def __init__(
+        self,
+        model_type: str | None = None,
+        lang: str | None = None,
+        rescale_with_baseline: bool = True,
+        device: str | None = None,
+    ):
+        """Initialize BERTScore metric.
+
+        Args:
+            model_type: Pre-trained model identifier. Popular choices:
+                - "microsoft/deberta-xlarge-mnli" (recommended, large)
+                - "microsoft/deberta-large-mnli" (good balance)
+                - "roberta-large" (fast, good quality)
+                - "bert-base-uncased" (fastest, lower quality)
+            lang: Language code (e.g., "en", "zh", "fr"). If provided,
+                automatically selects appropriate model.
+            rescale_with_baseline: Whether to rescale scores using baseline
+                (recommended for human correlation)
+            device: Device to use ("cuda", "cpu", or None for auto-detect)
+        """
+        self.name = "bertscore"
+        self.model_type = model_type
+        self.lang = lang
+        self.rescale_with_baseline = rescale_with_baseline
+        self.device = device
+
+        # Lazy import bert-score (not required for all users)
+        try:
+            import bert_score
+            self._bert_score = bert_score
+        except ImportError:
+            raise ImportError(
+                "bert-score is required for BERTScore metric. "
+                "Install it with: pip install bert-score"
+            )
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> MetricScore:
+        """Compute BERTScore.
+
+        Args:
+            prediction: Generated text (already extracted by pipeline)
+            references: List of reference texts
+            metadata: Optional metadata dict
+
+        Returns:
+            MetricScore with BERTScore F1 and precision/recall details
+        """
+        # Convert to strings
+        pred_str = str(prediction)
+        ref_strs = [str(ref) for ref in references]
+
+        # Compute BERTScore
+        # Note: bert_score.score expects lists of predictions and references
+        P, R, F1 = self._bert_score.score(
+            [pred_str] * len(ref_strs),  # Repeat prediction for each reference
+            ref_strs,
+            model_type=self.model_type,
+            lang=self.lang,
+            rescale_with_baseline=self.rescale_with_baseline,
+            device=self.device,
+            verbose=False,
+        )
+
+        # Take maximum F1 across references
+        max_idx = F1.argmax().item()
+        max_precision = P[max_idx].item()
+        max_recall = R[max_idx].item()
+        max_f1 = F1[max_idx].item()
+
+        return MetricScore(
+            metric_name=self.name,
+            value=max_f1,  # Use F1 as primary score
+            details={
+                "precision": max_precision,
+                "recall": max_recall,
+                "f1": max_f1,
+                "model_type": self.model_type or f"auto-{self.lang}",
+                "num_references": len(ref_strs),
+                "rescaled": self.rescale_with_baseline,
+            },
+            metadata=metadata or {},
+        )
+
+
+__all__ = ["BERTScore"]
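One behavior worth noting in compute above: with several references, the single prediction is tiled across all of them in one bert_score.score call and the best-matching reference (highest F1) determines the reported value, precision, and recall. A hedged usage sketch, assuming the bert-score extra is installed (the first call downloads the underlying model):

    from themis.evaluation.metrics.nlp import BERTScore

    metric = BERTScore(lang="en")   # lets bert-score pick a default English model
    score = metric.compute(
        prediction="The cat sat on the mat",
        references=["A cat is sitting on a mat", "Completely unrelated text"],
    )

    print(score.details["num_references"])   # 2
    # score.value is the F1 against the closer reference; the precision, recall,
    # and model actually used are also recorded in score.details.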