bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
"""Metrics computation using HuggingFace evaluate library.
|
|
2
|
+
|
|
3
|
+
This module provides metric computation functions for use with HuggingFace Trainer.
|
|
4
|
+
It uses the evaluate library for standardized, well-tested metrics.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
import evaluate
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from transformers import EvalPrediction, PreTrainedTokenizerBase
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def compute_binary_metrics(eval_pred: EvalPrediction) -> dict[str, float]:
    """Compute metrics for binary classification tasks.

    Uses HuggingFace evaluate library for accuracy, precision, recall, and F1.

    Parameters
    ----------
    eval_pred : EvalPrediction
        EvalPrediction object with predictions and label_ids attributes.
        predictions: logits, either one per sample with shape (n_samples,)
        or (n_samples, 1), or one per class with shape (n_samples, n_classes).
        label_ids: array of shape (n_samples,) with true labels (0 or 1)

    Returns
    -------
    dict[str, float]
        Dictionary with accuracy, precision, recall, and f1 metrics.

    Notes
    -----
    A single logit is mapped to class 1 when it is strictly positive, which
    is the same decision as ``sigmoid(logit) > 0.5``. When several logits per
    sample are given, the predicted class is the argmax over the last axis.

    Examples
    --------
    >>> from transformers import EvalPrediction
    >>> import numpy as np
    >>> predictions = np.array([1.4, -0.8, 2.2, -1.4])  # Logits
    >>> labels = np.array([1.0, 0.0, 1.0, 0.0])
    >>> eval_pred = EvalPrediction(predictions=predictions, label_ids=labels)
    >>> metrics = compute_binary_metrics(eval_pred)
    >>> "accuracy" in metrics
    True
    """
    # Load metrics from evaluate library
    accuracy_metric = evaluate.load("accuracy")
    precision_metric = evaluate.load("precision")
    recall_metric = evaluate.load("recall")
    f1_metric = evaluate.load("f1")

    # Extract predictions and labels
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids

    # A single-logit output may arrive as a column vector (n_samples, 1).
    # Taking argmax over that trailing axis would label every sample 0, so
    # collapse it to 1D first (compute_regression_metrics does the same).
    if predictions.ndim == 2 and predictions.shape[1] == 1:
        predictions = predictions.squeeze(1)

    if predictions.ndim == 1:
        # sigmoid(x) > 0.5 is equivalent to x > 0; comparing the logit
        # directly avoids overflow warnings from exp() on large negatives.
        preds = (predictions > 0).astype(int)
    else:
        # One logit per class: pick the highest-scoring class
        preds = np.argmax(predictions, axis=-1)

    # Ensure labels are integers
    labels = labels.astype(int)

    # Options shared by the metrics that score the positive class only
    binary_kwargs = {"average": "binary", "zero_division": 0}

    # Compute metrics
    accuracy = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]
    precision = precision_metric.compute(
        predictions=preds, references=labels, **binary_kwargs
    )["precision"]
    recall = recall_metric.compute(
        predictions=preds, references=labels, **binary_kwargs
    )["recall"]
    f1 = f1_metric.compute(
        predictions=preds, references=labels, **binary_kwargs
    )["f1"]

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def compute_regression_metrics(eval_pred: EvalPrediction) -> dict[str, float]:
    """Score continuous predictions against continuous targets.

    MSE and MAE come from the HuggingFace evaluate library; R² is derived
    directly from the residual and total sums of squares.

    Parameters
    ----------
    eval_pred : EvalPrediction
        Object exposing ``predictions`` and ``label_ids``.
        predictions: array of shape (n_samples, 1) with continuous values
        label_ids: array of shape (n_samples,) with true continuous values

    Returns
    -------
    dict[str, float]
        Mapping with the keys ``mse``, ``mae`` and ``r2``. ``r2`` is 0.0
        when the total sum of squares of the labels is not positive.

    Examples
    --------
    >>> from transformers import EvalPrediction
    >>> import numpy as np
    >>> predictions = np.array([[250.5], [300.2], [275.0]])  # Continuous values
    >>> labels = np.array([250.0, 300.0, 275.0])
    >>> eval_pred = EvalPrediction(predictions=predictions, label_ids=labels)
    >>> metrics = compute_regression_metrics(eval_pred)
    >>> "mse" in metrics
    True
    """
    mse_metric = evaluate.load("mse")
    mae_metric = evaluate.load("mae")

    y_pred = eval_pred.predictions
    y_true = eval_pred.label_ids

    # Bring predictions down to one value per sample where possible:
    # higher-rank arrays are flattened, a single trailing column is dropped.
    if y_pred.ndim > 2:
        y_pred = y_pred.flatten()
    elif y_pred.ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(1)

    # Targets are always reduced to 1D
    if y_true.ndim > 1:
        y_true = y_true.flatten()

    results = {
        "mse": mse_metric.compute(predictions=y_pred, references=y_true)["mse"],
        "mae": mae_metric.compute(predictions=y_pred, references=y_true)["mae"],
    }

    # Coefficient of determination, computed by hand
    residual_ss = np.sum((y_true - y_pred) ** 2)
    total_ss = np.sum((y_true - np.mean(y_true)) ** 2)
    if total_ss > 0:
        results["r2"] = 1 - (residual_ss / total_ss)
    else:
        # Constant targets leave R² undefined; report 0.0 instead
        results["r2"] = 0.0

    return results
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def compute_multiclass_metrics(
    eval_pred: EvalPrediction, num_labels: int
) -> dict[str, float]:
    """Score multi-class predictions with macro-averaged metrics.

    Accuracy, precision, recall and F1 are all obtained from the HuggingFace
    evaluate library; precision, recall and F1 use macro averaging.

    Parameters
    ----------
    eval_pred : EvalPrediction
        Object exposing ``predictions`` and ``label_ids``.
        predictions: array of shape (n_samples, n_classes) with logits
        label_ids: array of shape (n_samples,) with true labels
    num_labels : int
        Number of classes. Not referenced by the computation itself.

    Returns
    -------
    dict[str, float]
        Mapping with the keys ``accuracy``, ``precision``, ``recall``
        and ``f1``.

    Examples
    --------
    >>> from transformers import EvalPrediction
    >>> import numpy as np
    >>> predictions = np.array([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]])  # Logits
    >>> labels = np.array([1, 0])
    >>> eval_pred = EvalPrediction(predictions=predictions, label_ids=labels)
    >>> metrics = compute_multiclass_metrics(eval_pred, num_labels=3)
    >>> "accuracy" in metrics
    True
    """
    # Load every metric up front, in the order they are reported
    loaded = {
        name: evaluate.load(name)
        for name in ("accuracy", "precision", "recall", "f1")
    }

    logits = eval_pred.predictions
    targets = eval_pred.label_ids

    # 1D input is taken to already hold class ids; otherwise the predicted
    # class is the index of the largest logit along the last axis.
    if logits.ndim != 1:
        y_pred = np.argmax(logits, axis=-1)
    else:
        y_pred = logits.astype(int)

    y_true = targets.astype(int)

    results = {
        "accuracy": loaded["accuracy"].compute(
            predictions=y_pred, references=y_true
        )["accuracy"],
    }
    for name in ("precision", "recall", "f1"):
        results[name] = loaded[name].compute(
            predictions=y_pred,
            references=y_true,
            average="macro",
            zero_division=0,
        )[name]

    return results
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def compute_cloze_metrics(
    eval_pred: EvalPrediction, tokenizer: PreTrainedTokenizerBase
) -> dict[str, float]:
    """Compute metrics for cloze (MLM) tasks.

    Computes token-level metrics at masked positions:
    - accuracy: Whether predicted token exactly matches target
    - top_3_accuracy: Whether target is in top 3 predictions
    - top_5_accuracy: Whether target is in top 5 predictions
    - perplexity: Exponentiated average cross-entropy at masked positions

    Parameters
    ----------
    eval_pred : EvalPrediction
        EvalPrediction object with:
        - predictions: array of shape (n_samples, seq_len, vocab_size) with logits
        - label_ids: array of shape (n_samples, seq_len) with target_token_ids at
          masked positions, -100 elsewhere (HuggingFace ignore index)
    tokenizer : PreTrainedTokenizerBase
        HuggingFace tokenizer. Not used by the computation; kept so the
        signature stays stable for callers.

    Returns
    -------
    dict[str, float]
        Dictionary with accuracy, top_3_accuracy, top_5_accuracy, and perplexity.
        When the inputs cannot be scored (missing or empty arrays, unexpected
        ranks, mismatched shapes, or no masked position) the accuracies are
        0.0 and perplexity is ``inf``.

    Notes
    -----
    This function expects labels encoded in HuggingFace's MLM convention:
    - Target token IDs at positions to evaluate
    - -100 (ignore index) at all other positions

    Only the logit rows at masked positions are processed, so the cost of
    the argmax, top-k and softmax steps scales with the number of masked
    tokens rather than with ``n_samples * seq_len``. The log-softmax is
    evaluated in float64. If ``k`` is at least the vocabulary size, the
    top-k accuracy is reported as 1.0. Perplexity is reported as ``inf``
    when the average cross-entropy exceeds 100.

    Examples
    --------
    >>> from transformers import EvalPrediction, AutoTokenizer
    >>> import numpy as np
    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    >>> # Simulate: 2 samples, 5 positions, 100 vocab (simplified)
    >>> predictions = np.zeros((2, 5, 100))
    >>> predictions[0, 2, 42] = 10.0  # High logit for token 42 at pos 2
    >>> predictions[1, 1, 17] = 10.0  # High logit for token 17 at pos 1
    >>> labels = np.full((2, 5), -100)
    >>> labels[0, 2] = 42  # Target at pos 2
    >>> labels[1, 1] = 17  # Target at pos 1
    >>> eval_pred = EvalPrediction(predictions=predictions, label_ids=labels)
    >>> metrics = compute_cloze_metrics(eval_pred, tokenizer)
    >>> metrics["accuracy"]
    1.0
    """
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids

    # Result reported whenever there is nothing valid to score. It is built
    # per call, so callers are free to mutate the dict they receive.
    unscored = {
        "accuracy": 0.0,
        "top_3_accuracy": 0.0,
        "top_5_accuracy": 0.0,
        "perplexity": float("inf"),
    }

    # Handle missing or empty inputs
    if predictions is None or predictions.size == 0 or labels is None:
        return unscored

    # Validate ranks and that labels line up with the first two logit axes
    if (
        predictions.ndim != 3
        or labels.ndim != 2
        or predictions.shape[:2] != labels.shape
    ):
        return unscored

    # Masked positions are those whose label is not the ignore index
    mask = labels != -100
    n_total = int(mask.sum())
    if n_total == 0:
        return unscored

    # Keep only the rows that are scored: (n_total, vocab_size) and (n_total,)
    masked_logits = predictions[mask]
    targets = labels[mask]
    vocab_size = masked_logits.shape[-1]

    # Top-1 accuracy
    n_correct = int((np.argmax(masked_logits, axis=-1) == targets).sum())
    accuracy = float(n_correct) / float(n_total)

    def compute_topk_accuracy(k: int) -> float:
        """Compute top-k accuracy at masked positions."""
        if k >= vocab_size:
            # Every token is in the top-k
            return 1.0
        # argpartition places the k largest logits last without a full sort
        topk_indices = np.argpartition(masked_logits, -k, axis=-1)[:, -k:]
        in_topk = (topk_indices == targets[:, np.newaxis]).any(axis=-1)
        return float(int(in_topk.sum())) / float(n_total)

    top_3_accuracy = compute_topk_accuracy(3)
    top_5_accuracy = compute_topk_accuracy(5)

    # Perplexity = exp(average cross-entropy), using a max-shifted
    # log-softmax for numerical stability. Upcasting only the masked rows
    # keeps the extra memory small while avoiding low-precision sums.
    logits64 = masked_logits.astype(np.float64, copy=False)
    shifted = logits64 - logits64.max(axis=-1, keepdims=True)
    log_normalizer = np.log(np.exp(shifted).sum(axis=-1))

    # Negative ids other than -100 are read as token 0, as before
    safe_targets = np.where(targets >= 0, targets, 0)
    target_log_probs = shifted[np.arange(n_total), safe_targets] - log_normalizer

    avg_ce = float(-target_log_probs.mean())
    # Clip to avoid overflow in exp()
    perplexity = float("inf") if avg_ce > 100 else float(np.exp(avg_ce))

    return {
        "accuracy": accuracy,
        "top_3_accuracy": top_3_accuracy,
        "top_5_accuracy": top_5_accuracy,
        "perplexity": perplexity,
    }
|