themis-eval 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +429 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/main.py +427 -57
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/core/entities.py +23 -3
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/pipelines/standard_pipeline.py +68 -8
- themis/experiment/cache_manager.py +8 -3
- themis/experiment/export.py +110 -2
- themis/experiment/orchestrator.py +109 -11
- themis/experiment/storage.py +1457 -110
- themis/generation/providers/litellm_provider.py +46 -0
- themis/generation/runner.py +22 -6
- themis/integrations/huggingface.py +12 -1
- themis/integrations/wandb.py +13 -1
- themis/interfaces/__init__.py +86 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis_eval-0.2.1.dist-info/METADATA +596 -0
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.1.dist-info}/RECORD +42 -19
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.1.dist-info}/WHEEL +1 -1
- themis_eval-0.1.1.dist-info/METADATA +0 -758
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.1.dist-info → themis_eval-0.2.1.dist-info}/top_level.txt +0 -0
themis/evaluation/metrics/nlp/bleu.py
@@ -0,0 +1,129 @@
+"""BLEU (Bilingual Evaluation Understudy) metric implementation.
+
+BLEU measures the similarity between generated text and reference translations
+using n-gram precision with brevity penalty.
+
+References:
+    Papineni et al. (2002). BLEU: a Method for Automatic Evaluation of Machine Translation.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Sequence
+
+from themis.core.entities import MetricScore
+from themis.interfaces import Metric
+
+
+class BLEU(Metric):
+    """BLEU metric using sacrebleu library.
+
+    BLEU is a precision-based metric that computes n-gram overlap between
+    generated text and reference translations. It includes a brevity penalty
+    to penalize short translations.
+
+    Attributes:
+        name: Metric identifier ("bleu")
+        lowercase: Whether to lowercase text before scoring
+        tokenize: Tokenization method ("13a", "intl", "zh", "ja-mecab", etc.)
+        max_ngram_order: Maximum n-gram order (default: 4)
+
+    Example:
+        >>> from themis.evaluation.metrics.nlp import BLEU
+        >>> metric = BLEU()
+        >>> score = metric.compute(
+        ...     prediction="The cat sat on the mat",
+        ...     references=["The cat is on the mat", "A cat is sitting on a mat"]
+        ... )
+        >>> print(f"BLEU: {score.value:.4f}")
+        BLEU: 0.4523
+    """
+
+    requires_reference = True
+
+    def __init__(
+        self,
+        lowercase: bool = False,
+        tokenize: str = "13a",
+        max_ngram_order: int = 4,
+    ):
+        """Initialize BLEU metric.
+
+        Args:
+            lowercase: Convert text to lowercase before scoring
+            tokenize: Tokenization method:
+                - "13a": Default Moses tokenizer (punctuation split)
+                - "intl": International tokenizer
+                - "zh": Chinese tokenizer
+                - "ja-mecab": Japanese MeCab tokenizer
+                - "none": No tokenization
+            max_ngram_order: Maximum n-gram order (typically 4)
+        """
+        self.name = "bleu"
+        self.lowercase = lowercase
+        self.tokenize = tokenize
+        self.max_ngram_order = max_ngram_order
+
+        # Lazy import sacrebleu (not required for all users)
+        try:
+            from sacrebleu import BLEU as SacreBLEU
+            self._scorer = SacreBLEU(
+                lowercase=lowercase,
+                tokenize=tokenize,
+                max_ngram_order=max_ngram_order,
+            )
+        except ImportError:
+            raise ImportError(
+                "sacrebleu is required for BLEU metric. "
+                "Install it with: pip install sacrebleu"
+            )
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> MetricScore:
+        """Compute BLEU score.
+
+        Args:
+            prediction: Generated text (already extracted by pipeline)
+            references: List of reference translations
+            metadata: Optional metadata dict
+
+        Returns:
+            MetricScore with BLEU value (0.0-1.0) and detailed scores
+        """
+        # Convert to strings
+        pred_str = str(prediction)
+        ref_strs = [str(ref) for ref in references]
+
+        # Compute BLEU score
+        score_obj = self._scorer.sentence_score(pred_str, ref_strs)
+
+        # Extract scores (sacrebleu returns 0-100, we normalize to 0-1)
+        bleu_score = score_obj.score / 100.0
+
+        # Extract precision scores for each n-gram
+        precisions = [p / 100.0 for p in score_obj.precisions]
+
+        return MetricScore(
+            metric_name=self.name,
+            value=bleu_score,
+            details={
+                "bleu_score": bleu_score,
+                "precision_1": precisions[0] if len(precisions) > 0 else 0.0,
+                "precision_2": precisions[1] if len(precisions) > 1 else 0.0,
+                "precision_3": precisions[2] if len(precisions) > 2 else 0.0,
+                "precision_4": precisions[3] if len(precisions) > 3 else 0.0,
+                "brevity_penalty": score_obj.bp,
+                "length_ratio": score_obj.sys_len / score_obj.ref_len if score_obj.ref_len > 0 else 0.0,
+                "sys_len": score_obj.sys_len,
+                "ref_len": score_obj.ref_len,
+            },
+            metadata=metadata or {},
+        )
+
+
+__all__ = ["BLEU"]
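A minimal usage sketch for the BLEU metric added above, based only on the API visible in this diff (the constructor options and the MetricScore value/details fields); the input strings are illustrative and sacrebleu must be installed.

from themis.evaluation.metrics.nlp import BLEU

# Scores are sentence-level (sacrebleu sentence_score) and normalized to 0-1.
metric = BLEU(lowercase=True, tokenize="13a")
score = metric.compute(
    prediction="The cat sat on the mat",
    references=["The cat is on the mat"],
)
print(score.value)
print(score.details["precision_1"], score.details["brevity_penalty"])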
themis/evaluation/metrics/nlp/meteor.py
@@ -0,0 +1,153 @@
+"""METEOR (Metric for Evaluation of Translation with Explicit ORdering) metric.
+
+METEOR is an MT evaluation metric that addresses some weaknesses of BLEU by
+incorporating stemming, synonymy, and explicit word ordering.
+
+References:
+    Banerjee & Lavie (2005). METEOR: An Automatic Metric for MT Evaluation
+    with Improved Correlation with Human Judgments.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Sequence
+
+from themis.core.entities import MetricScore
+from themis.interfaces import Metric
+
+
+class METEOR(Metric):
+    """METEOR metric using nltk library.
+
+    METEOR compares generated text to references using:
+    - Exact word matching
+    - Stemming (using Porter stemmer)
+    - Synonymy (using WordNet)
+    - Word order (using chunk matching)
+
+    It computes a weighted F-score with emphasis on recall and applies a penalty
+    for word order differences.
+
+    Attributes:
+        name: Metric identifier ("meteor")
+        alpha: Weight for precision vs recall (default: 0.9, favors recall)
+        beta: Weight for fragmentation penalty (default: 3.0)
+        gamma: Fragmentation penalty coefficient (default: 0.5)
+
+    Example:
+        >>> from themis.evaluation.metrics.nlp import METEOR
+        >>> metric = METEOR()
+        >>> score = metric.compute(
+        ...     prediction="The cat sat on the mat",
+        ...     references=["The cat is on the mat", "A cat sits on a mat"]
+        ... )
+        >>> print(f"METEOR: {score.value:.4f}")
+        METEOR: 0.8234
+    """
+
+    requires_reference = True
+
+    def __init__(
+        self,
+        alpha: float = 0.9,
+        beta: float = 3.0,
+        gamma: float = 0.5,
+    ):
+        """Initialize METEOR metric.
+
+        Args:
+            alpha: Weight for precision vs recall (0-1). Higher values favor recall.
+                Default 0.9 emphasizes recall like original METEOR.
+            beta: Weight for fragmentation penalty (typically 3.0)
+            gamma: Fragmentation penalty coefficient (typically 0.5)
+        """
+        self.name = "meteor"
+        self.alpha = alpha
+        self.beta = beta
+        self.gamma = gamma
+
+        # Lazy import nltk (not required for all users)
+        try:
+            from nltk.translate import meteor_score as meteor
+            self._meteor = meteor
+
+            # Download required NLTK data if not present
+            import nltk
+            try:
+                nltk.data.find('corpora/wordnet')
+            except LookupError:
+                print("Downloading WordNet data for METEOR...")
+                nltk.download('wordnet', quiet=True)
+
+            try:
+                nltk.data.find('omw-1.4')
+            except LookupError:
+                print("Downloading OMW data for METEOR...")
+                nltk.download('omw-1.4', quiet=True)
+
+        except ImportError:
+            raise ImportError(
+                "nltk is required for METEOR metric. "
+                "Install it with: pip install nltk"
+            )
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> MetricScore:
+        """Compute METEOR score.
+
+        Args:
+            prediction: Generated text (already extracted by pipeline)
+            references: List of reference texts
+            metadata: Optional metadata dict
+
+        Returns:
+            MetricScore with METEOR value (0.0-1.0)
+        """
+        # Convert to strings and tokenize
+        pred_str = str(prediction)
+        ref_strs = [str(ref) for ref in references]
+
+        # Tokenize (simple whitespace tokenization)
+        pred_tokens = pred_str.split()
+        ref_tokens_list = [ref.split() for ref in ref_strs]
+
+        # Compute METEOR score
+        # Note: nltk's meteor_score takes one reference at a time
+        # We compute for each reference and take the maximum
+        max_score = 0.0
+
+        for ref_tokens in ref_tokens_list:
+            try:
+                score = self._meteor.meteor_score(
+                    [ref_tokens],  # References should be list of tokenized references
+                    pred_tokens,  # Hypothesis is tokenized prediction
+                    alpha=self.alpha,
+                    beta=self.beta,
+                    gamma=self.gamma,
+                )
+                max_score = max(max_score, score)
+            except Exception as e:
+                # Handle edge cases (empty strings, etc.)
+                print(f"Warning: METEOR computation failed: {e}")
+                continue
+
+        return MetricScore(
+            metric_name=self.name,
+            value=max_score,
+            details={
+                "meteor_score": max_score,
+                "num_references": len(ref_strs),
+                "alpha": self.alpha,
+                "beta": self.beta,
+                "gamma": self.gamma,
+            },
+            metadata=metadata or {},
+        )
+
+
+__all__ = ["METEOR"]
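A similar sketch for the METEOR metric added above, relying only on the behavior shown in the diff: whitespace tokenization, automatic WordNet/OMW downloads on first use, and the maximum score over all references; the inputs are illustrative.

from themis.evaluation.metrics.nlp import METEOR

# nltk must be installed; missing WordNet data is downloaded on first use.
metric = METEOR(alpha=0.9, beta=3.0, gamma=0.5)
score = metric.compute(
    prediction="The cat sat on the mat",
    references=["The cat is on the mat", "A cat sits on a mat"],
)
# value is the best score across references; details echo the parameters used.
print(score.value, score.details["num_references"])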
themis/evaluation/metrics/nlp/rouge.py
@@ -0,0 +1,136 @@
+"""ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric.
+
+ROUGE measures overlap between generated text and reference summaries
+using n-grams and longest common subsequence.
+
+References:
+    Lin (2004). ROUGE: A Package for Automatic Evaluation of Summaries.
+"""
+
+from __future__ import annotations
+
+from enum import Enum
+from typing import Any, Sequence
+
+from themis.core.entities import MetricScore
+from themis.interfaces import Metric
+
+
+class ROUGEVariant(str, Enum):
+    """ROUGE metric variants."""
+
+    ROUGE_1 = "rouge1"  # Unigram overlap
+    ROUGE_2 = "rouge2"  # Bigram overlap
+    ROUGE_L = "rougeL"  # Longest common subsequence
+    ROUGE_L_SUM = "rougeLsum"  # LCS with summary-level computation
+
+
+class ROUGE(Metric):
+    """ROUGE metric using rouge-score library.
+
+    ROUGE is a recall-oriented metric that measures n-gram overlap between
+    generated text and reference summaries. It's commonly used for evaluating
+    text summarization and text generation tasks.
+
+    Variants:
+    - ROUGE-1: Unigram overlap
+    - ROUGE-2: Bigram overlap
+    - ROUGE-L: Longest common subsequence (sentence-level)
+    - ROUGE-Lsum: Longest common subsequence (summary-level)
+
+    Attributes:
+        name: Metric identifier (e.g., "rouge1", "rouge2", "rougeL")
+        variant: Which ROUGE variant to compute
+        use_stemmer: Whether to use Porter stemmer
+
+    Example:
+        >>> from themis.evaluation.metrics.nlp import ROUGE, ROUGEVariant
+        >>> metric = ROUGE(variant=ROUGEVariant.ROUGE_2)
+        >>> score = metric.compute(
+        ...     prediction="The quick brown fox jumps over the lazy dog",
+        ...     references=["A quick brown fox jumped over a lazy dog"]
+        ... )
+        >>> print(f"ROUGE-2 F1: {score.value:.4f}")
+        ROUGE-2 F1: 0.6154
+    """
+
+    requires_reference = True
+
+    def __init__(
+        self,
+        variant: ROUGEVariant = ROUGEVariant.ROUGE_L,
+        use_stemmer: bool = True,
+    ):
+        """Initialize ROUGE metric.
+
+        Args:
+            variant: Which ROUGE variant to compute
+            use_stemmer: Whether to use Porter stemmer for word matching
+        """
+        self.variant = variant
+        self.use_stemmer = use_stemmer
+        self.name = variant.value
+
+        # Lazy import rouge-score (not required for all users)
+        try:
+            from rouge_score import rouge_scorer
+            self._scorer = rouge_scorer.RougeScorer(
+                [variant.value],
+                use_stemmer=use_stemmer,
+            )
+        except ImportError:
+            raise ImportError(
+                "rouge-score is required for ROUGE metric. "
+                "Install it with: pip install rouge-score"
+            )
+
+    def compute(
+        self,
+        *,
+        prediction: Any,
+        references: Sequence[Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> MetricScore:
+        """Compute ROUGE score.
+
+        Args:
+            prediction: Generated text (already extracted by pipeline)
+            references: List of reference summaries
+            metadata: Optional metadata dict
+
+        Returns:
+            MetricScore with ROUGE F1 score and precision/recall details
+        """
+        # Convert to strings
+        pred_str = str(prediction)
+        ref_strs = [str(ref) for ref in references]
+
+        # Compute ROUGE for each reference and take the maximum
+        max_precision = 0.0
+        max_recall = 0.0
+        max_f1 = 0.0
+
+        for ref_str in ref_strs:
+            scores = self._scorer.score(ref_str, pred_str)
+            rouge_score = scores[self.variant.value]
+
+            if rouge_score.fmeasure > max_f1:
+                max_precision = rouge_score.precision
+                max_recall = rouge_score.recall
+                max_f1 = rouge_score.fmeasure
+
+        return MetricScore(
+            metric_name=self.name,
+            value=max_f1,  # Use F1 as primary score
+            details={
+                "precision": max_precision,
+                "recall": max_recall,
+                "f1": max_f1,
+                "variant": self.variant.value,
+                "num_references": len(ref_strs),
+            },
+            metadata=metadata or {},
+        )
+
+
+__all__ = ["ROUGE", "ROUGEVariant"]
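And a sketch for the ROUGE metric added above, again assuming only the API shown in the diff; the texts are illustrative. The primary value is the best F1 across references, with precision and recall exposed in details.

from themis.evaluation.metrics.nlp import ROUGE, ROUGEVariant

# rouge-score must be installed; ROUGE-L is the default variant.
metric = ROUGE(variant=ROUGEVariant.ROUGE_1, use_stemmer=True)
score = metric.compute(
    prediction="The quick brown fox jumps over the lazy dog",
    references=["A quick brown fox jumped over a lazy dog"],
)
print(score.value)  # best F1 over the references
print(score.details["precision"], score.details["recall"])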
themis/evaluation/pipelines/standard_pipeline.py
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import logging
 import time
+import warnings
 from typing import Callable, Sequence
 
 from themis.core import entities as core_entities
@@ -35,19 +36,49 @@ def _default_reference_selector(record: core_entities.GenerationRecord):
     return reference.value
 
 
-def _normalize_references(reference):
-    """Normalize reference to list format.
+def _normalize_references(reference) -> list:
+    """Normalize reference to list format for metric consumption.
+
+    This function converts various reference formats into a standardized list
+    that metrics can reliably consume. The normalized format is always a list
+    where each element represents one reference value.
 
     Args:
-        reference: Reference value
+        reference: Reference value in various formats:
+            - Reference object: Extracts .value field
+            - dict: Kept as-is in a list (for multi-value references)
+            - list/tuple: Returned as list
+            - scalar: Wrapped in a list
 
     Returns:
-        List of
+        List of reference values. Each element can be:
+        - A scalar value (str, int, float, bool)
+        - A dict (for multi-value references like {"target": 122, "numbers": [...]})
+        - Any other type from the original reference
+
+    Examples:
+        >>> _normalize_references(Reference(kind="answer", value="42"))
+        ["42"]
+
+        >>> _normalize_references(Reference(kind="task", value={"target": 122, "numbers": [25, 50]}))
+        [{"target": 122, "numbers": [25, 50]}]
+
+        >>> _normalize_references(["yes", "no", "maybe"])
+        ["yes", "no", "maybe"]
+
+        >>> _normalize_references("42")
+        ["42"]
+
+    Note:
+        Metrics receive references in this normalized format and should handle
+        both simple values and dict values appropriately.
     """
     if isinstance(reference, core_entities.Reference):
         reference = reference.value
     if isinstance(reference, list):
         return reference
+    if isinstance(reference, tuple):
+        return list(reference)
     return [reference]
 
 
@@ -89,12 +120,21 @@ class EvaluationPipeline:
         Args:
             extractor: Extractor for parsing model output
             metrics: List of metrics to compute
-            reference_selector: Optional function to extract reference
-
+            reference_selector: Optional function to extract reference from record.
+                If provided, this takes precedence over item.reference from strategies.
+            strategy_resolver: Optional function to resolve evaluation strategy.
+                If using a custom reference_selector with DefaultEvaluationStrategy,
+                the selector will take precedence.
+
+        Note:
+            When using DefaultEvaluationStrategy with a custom reference_selector,
+            the reference_selector will override the default behavior. Consider
+            using a custom strategy if you need more control over reference selection.
         """
         self._extractor = extractor
         self._metrics = list(metrics)
-        self._reference_selector = reference_selector
+        self._reference_selector = reference_selector
+        self._has_custom_reference_selector = reference_selector is not None
         self._strategy_resolver = strategy_resolver or (
             lambda record: evaluation_strategies.DefaultEvaluationStrategy()
         )
@@ -102,6 +142,17 @@ class EvaluationPipeline:
             tuple[str, Callable[[core_entities.GenerationRecord], bool]]
         ] = []
 
+        # Validation: warn if custom reference_selector is used with default strategy
+        if self._has_custom_reference_selector and strategy_resolver is None:
+            warnings.warn(
+                "Custom reference_selector provided without custom strategy_resolver. "
+                "The reference_selector will take precedence over DefaultEvaluationStrategy's "
+                "reference handling. If you need more control, consider providing a custom "
+                "strategy_resolver that sets reference=None in EvaluationItem.",
+                UserWarning,
+                stacklevel=2,
+            )
+
     def evaluate(
         self, records: Sequence[core_entities.GenerationRecord]
     ) -> EvaluationReport:
@@ -167,7 +218,16 @@ class EvaluationPipeline:
                 record_failures.append(message)
                 continue
 
-
+            # CRITICAL: Always call reference_selector if provided (takes precedence)
+            # This fixes the issue where DefaultEvaluationStrategy's reference
+            # would prevent custom reference_selector from being called
+            if self._has_custom_reference_selector:
+                reference = self._reference_selector(record)
+            elif item.reference is not None:
+                reference = item.reference
+            else:
+                reference = _default_reference_selector(record)
+
             references = (
                 _normalize_references(reference)
                 if reference is not None
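The pipeline hunks above change reference precedence: a custom reference_selector passed to EvaluationPipeline now always overrides the reference attached by the evaluation strategy, and combining it with the default strategy resolver emits a UserWarning. A minimal sketch of that interaction, assuming the import path of the module patched above; my_extractor and my_metrics stand in for project-specific objects and are not part of this diff.

import warnings

from themis.evaluation.pipelines.standard_pipeline import EvaluationPipeline

def pick_reference(record):
    # Hypothetical selector: return whatever your records store as ground truth.
    ...

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pipeline = EvaluationPipeline(
        extractor=my_extractor,  # placeholder
        metrics=my_metrics,      # placeholder
        reference_selector=pick_reference,
    )

# One UserWarning is expected: custom selector + default strategy resolver.
assert any(issubclass(w.category, UserWarning) for w in caught)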
themis/experiment/cache_manager.py
@@ -65,18 +65,21 @@ class CacheManager:
             return {}
         return self._storage.load_cached_records(run_id)
 
-    def load_cached_evaluations(self, run_id: str) -> dict[str, EvaluationRecord]:
+    def load_cached_evaluations(
+        self, run_id: str, evaluation_config: dict | None = None
+    ) -> dict[str, EvaluationRecord]:
         """Load cached evaluation records for resuming.
 
         Args:
             run_id: Unique run identifier
+            evaluation_config: Evaluation configuration (metrics, extractor) for cache matching
 
         Returns:
             Dictionary mapping cache keys to evaluation records
         """
         if not self._enable_resume or self._storage is None:
             return {}
-        return self._storage.load_cached_evaluations(run_id)
+        return self._storage.load_cached_evaluations(run_id, evaluation_config=evaluation_config)
 
     def save_generation_record(
         self,
@@ -99,6 +102,7 @@ class CacheManager:
         run_id: str,
         generation_record: GenerationRecord,
         evaluation_record: EvaluationRecord,
+        evaluation_config: dict | None = None,
     ) -> None:
         """Save a single evaluation record.
 
@@ -106,10 +110,11 @@ class CacheManager:
             run_id: Unique run identifier
             generation_record: Corresponding generation record
             evaluation_record: Evaluation record to save
+            evaluation_config: Evaluation configuration for cache invalidation
         """
         if self._storage is not None and self._enable_cache:
             self._storage.append_evaluation(
-                run_id, generation_record, evaluation_record
+                run_id, generation_record, evaluation_record, evaluation_config=evaluation_config
             )
 
     def get_run_path(self, run_id: str) -> str | None: