bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/trainers/metrics.py
@@ -0,0 +1,424 @@
+ """Metrics computation using HuggingFace evaluate library.
+
+ This module provides metric computation functions for use with HuggingFace Trainer.
+ It uses the evaluate library for standardized, well-tested metrics.
+ """
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import evaluate
+ import numpy as np
+
+ if TYPE_CHECKING:
+     from transformers import EvalPrediction, PreTrainedTokenizerBase
+
+
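The functions below are designed to be handed to a HuggingFace Trainer as its compute_metrics callback. A minimal, hedged wiring sketch follows; the names build_trainer, model, train_ds, and eval_ds are illustrative placeholders, not objects defined in this package, and the import path is assumed from the wheel's file layout.

from transformers import Trainer, TrainingArguments

from bead.active_learning.trainers.metrics import compute_binary_metrics

def build_trainer(model, train_ds, eval_ds):
    """Hypothetical helper: plug a metrics function into a Trainer."""
    return Trainer(
        model=model,                               # assumed: a binary classification model
        args=TrainingArguments(output_dir="out"),  # minimal training arguments
        train_dataset=train_ds,                    # assumed: tokenized datasets
        eval_dataset=eval_ds,
        compute_metrics=compute_binary_metrics,    # called with an EvalPrediction per eval pass
    )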
+ def compute_binary_metrics(eval_pred: EvalPrediction) -> dict[str, float]:
+     """Compute metrics for binary classification tasks.
+
+     Uses HuggingFace evaluate library for accuracy, precision, recall, and F1.
+
+     Parameters
+     ----------
+     eval_pred : EvalPrediction
+         EvalPrediction object with predictions and label_ids attributes.
+         predictions: array of shape (n_samples,) with logits
+         label_ids: array of shape (n_samples,) with true labels (0 or 1)
+
+     Returns
+     -------
+     dict[str, float]
+         Dictionary with accuracy, precision, recall, and f1 metrics.
+
+     Examples
+     --------
+     >>> from transformers import EvalPrediction
+     >>> import numpy as np
+     >>> predictions = np.array([0.8, 0.3, 0.9, 0.2])  # Logits
+     >>> labels = np.array([1.0, 0.0, 1.0, 0.0])
+     >>> eval_pred = EvalPrediction(predictions=predictions, label_ids=labels)
+     >>> metrics = compute_binary_metrics(eval_pred)
+     >>> "accuracy" in metrics
+     True
+     """
+     # Load metrics from evaluate library
+     accuracy_metric = evaluate.load("accuracy")
+     precision_metric = evaluate.load("precision")
+     recall_metric = evaluate.load("recall")
+     f1_metric = evaluate.load("f1")
+
+     # Extract predictions and labels
+     predictions = eval_pred.predictions
+     labels = eval_pred.label_ids
+
+     # Convert logits to predictions (binary: apply sigmoid and threshold)
+     if predictions.ndim == 1:
+         # Single logit per sample
+         preds = (1 / (1 + np.exp(-predictions)) > 0.5).astype(int)
+     else:
+         # Multiple logits (shouldn't happen for binary, but handle it)
+         preds = np.argmax(predictions, axis=-1)
+
+     # Ensure labels are integers
+     labels = labels.astype(int)
+
+     # Compute metrics
+     accuracy = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]
+     precision = precision_metric.compute(
+         predictions=preds, references=labels, average="binary", zero_division=0
+     )["precision"]
+     recall = recall_metric.compute(
+         predictions=preds, references=labels, average="binary", zero_division=0
+     )["recall"]
+     f1 = f1_metric.compute(
+         predictions=preds, references=labels, average="binary", zero_division=0
+     )["f1"]
+
+     return {
+         "accuracy": accuracy,
+         "precision": precision,
+         "recall": recall,
+         "f1": f1,
+     }
+
+
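A side note on the logit-to-label step above: since the sigmoid is monotonic, sigmoid(x) > 0.5 holds exactly when x > 0, so thresholding the raw logits at zero yields the same hard predictions. A small illustrative check (not part of the package):

import numpy as np

logits = np.array([0.8, -0.3, 0.9, -0.2])
via_sigmoid = (1 / (1 + np.exp(-logits)) > 0.5).astype(int)  # as in compute_binary_metrics
via_sign = (logits > 0).astype(int)                          # equivalent zero-threshold shortcut
assert (via_sigmoid == via_sign).all()                       # both give [1, 0, 1, 0]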
+ def compute_regression_metrics(eval_pred: EvalPrediction) -> dict[str, float]:
+     """Compute metrics for regression tasks.
+
+     Uses HuggingFace evaluate library for MSE, MAE, and R².
+
+     Parameters
+     ----------
+     eval_pred : EvalPrediction
+         EvalPrediction object with predictions and label_ids attributes.
+         predictions: array of shape (n_samples, 1) with continuous values
+         label_ids: array of shape (n_samples,) with true continuous values
+
+     Returns
+     -------
+     dict[str, float]
+         Dictionary with mse, mae, and r2 metrics.
+
+     Examples
+     --------
+     >>> from transformers import EvalPrediction
+     >>> import numpy as np
+     >>> predictions = np.array([[250.5], [300.2], [275.0]])  # Continuous values
+     >>> labels = np.array([250.0, 300.0, 275.0])
+     >>> eval_pred = EvalPrediction(predictions=predictions, label_ids=labels)
+     >>> metrics = compute_regression_metrics(eval_pred)
+     >>> "mse" in metrics
+     True
+     """
+     # Load metrics from evaluate library
+     mse_metric = evaluate.load("mse")
+     mae_metric = evaluate.load("mae")
+
+     # Extract predictions and labels
+     predictions = eval_pred.predictions
+     labels = eval_pred.label_ids
+
+     # Handle predictions shape: (n_samples, 1) -> (n_samples,)
+     if predictions.ndim == 2 and predictions.shape[1] == 1:
+         predictions = predictions.squeeze(1)
+     elif predictions.ndim > 2:
+         # Flatten if needed
+         predictions = predictions.flatten()
+
+     # Ensure labels are 1D
+     if labels.ndim > 1:
+         labels = labels.flatten()
+
+     # Compute metrics
+     mse = mse_metric.compute(predictions=predictions, references=labels)["mse"]
+     mae = mae_metric.compute(predictions=predictions, references=labels)["mae"]
+
+     # Compute R² manually (evaluate library doesn't have r2)
+     ss_res = np.sum((labels - predictions) ** 2)
+     ss_tot = np.sum((labels - np.mean(labels)) ** 2)
+     r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0.0
+
+     return {
+         "mse": mse,
+         "mae": mae,
+         "r2": r2,
+     }
+
+
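The R² term above is computed by hand because, as the inline comment notes, the evaluate library does not provide it. A quick numeric check of the formula r2 = 1 - ss_res / ss_tot, using illustrative values rather than package data:

import numpy as np

labels = np.array([1.0, 2.0, 3.0])
preds = np.array([1.1, 1.9, 3.2])
ss_res = np.sum((labels - preds) ** 2)           # 0.01 + 0.01 + 0.04 = 0.06
ss_tot = np.sum((labels - labels.mean()) ** 2)   # 1.0 + 0.0 + 1.0 = 2.0
r2 = 1 - ss_res / ss_tot                         # 1 - 0.03 = 0.97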
+ def compute_multiclass_metrics(
+     eval_pred: EvalPrediction, num_labels: int
+ ) -> dict[str, float]:
+     """Compute metrics for multi-class classification tasks.
+
+     Uses HuggingFace evaluate library for accuracy, precision, recall, and F1.
+
+     Parameters
+     ----------
+     eval_pred : EvalPrediction
+         EvalPrediction object with predictions and label_ids attributes.
+         predictions: array of shape (n_samples, n_classes) with logits
+         label_ids: array of shape (n_samples,) with true labels
+     num_labels : int
+         Number of classes.
+
+     Returns
+     -------
+     dict[str, float]
+         Dictionary with accuracy, precision, recall, and f1 metrics.
+
+     Examples
+     --------
+     >>> from transformers import EvalPrediction
+     >>> import numpy as np
+     >>> predictions = np.array([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]])  # Logits
+     >>> labels = np.array([1, 0])
+     >>> eval_pred = EvalPrediction(predictions=predictions, label_ids=labels)
+     >>> metrics = compute_multiclass_metrics(eval_pred, num_labels=3)
+     >>> "accuracy" in metrics
+     True
+     """
+     # Load metrics
+     accuracy_metric = evaluate.load("accuracy")
+     precision_metric = evaluate.load("precision")
+     recall_metric = evaluate.load("recall")
+     f1_metric = evaluate.load("f1")
+
+     # Extract predictions and labels
+     predictions = eval_pred.predictions
+     labels = eval_pred.label_ids
+
+     # Convert logits to predictions
+     if predictions.ndim == 1:
+         # Single logit per sample (shouldn't happen for multi-class)
+         preds = predictions.astype(int)
+     else:
+         # Multiple logits: take argmax
+         preds = np.argmax(predictions, axis=-1)
+
+     # Ensure labels are integers
+     labels = labels.astype(int)
+
+     # Compute metrics with macro averaging
+     accuracy = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]
+     precision = precision_metric.compute(
+         predictions=preds,
+         references=labels,
+         average="macro",
+         zero_division=0,
+     )["precision"]
+     recall = recall_metric.compute(
+         predictions=preds,
+         references=labels,
+         average="macro",
+         zero_division=0,
+     )["recall"]
+     f1 = f1_metric.compute(
+         predictions=preds,
+         references=labels,
+         average="macro",
+         zero_division=0,
+     )["f1"]
+
+     return {
+         "accuracy": accuracy,
+         "precision": precision,
+         "recall": recall,
+         "f1": f1,
+     }
+
+
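Unlike the other functions in this module, compute_multiclass_metrics takes an extra num_labels argument, while Trainer invokes compute_metrics with a single EvalPrediction. A hedged usage sketch: bind the extra argument first, for example with functools.partial (an assumed calling convention, not prescribed by the package; the import path is taken from the wheel's file layout):

from functools import partial

from bead.active_learning.trainers.metrics import compute_multiclass_metrics

# Fix num_labels up front so the result matches Trainer's expected
# one-argument compute_metrics signature.
compute_metrics = partial(compute_multiclass_metrics, num_labels=3)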
+ def compute_cloze_metrics(
+     eval_pred: EvalPrediction, tokenizer: PreTrainedTokenizerBase
+ ) -> dict[str, float]:
+     """Compute metrics for cloze (MLM) tasks.
+
+     Computes token-level metrics at masked positions:
+     - accuracy: Whether predicted token exactly matches target
+     - top_3_accuracy: Whether target is in top 3 predictions
+     - top_5_accuracy: Whether target is in top 5 predictions
+     - perplexity: Exponentiated average cross-entropy at masked positions
+
+     Parameters
+     ----------
+     eval_pred : EvalPrediction
+         EvalPrediction object with:
+         - predictions: array of shape (n_samples, seq_len, vocab_size) with logits
+         - label_ids: array of shape (n_samples, seq_len) with target_token_ids at
+           masked positions, -100 elsewhere (HuggingFace ignore index)
+     tokenizer : PreTrainedTokenizerBase
+         HuggingFace tokenizer. Used for type checking and potential future extensions.
+
+     Returns
+     -------
+     dict[str, float]
+         Dictionary with accuracy, top_3_accuracy, top_5_accuracy, and perplexity.
+
+     Notes
+     -----
+     This function expects labels encoded in HuggingFace's MLM convention:
+     - Target token IDs at positions to evaluate
+     - -100 (ignore index) at all other positions
+
+     The ClozeMLMTrainer's prediction_step() creates this encoding from
+     masked_positions and target_token_ids in the dataset.
+
+     Examples
+     --------
+     >>> from transformers import EvalPrediction, AutoTokenizer
+     >>> import numpy as np
+     >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+     >>> # Simulate: 2 samples, 5 positions, 100 vocab (simplified)
+     >>> predictions = np.zeros((2, 5, 100))
+     >>> predictions[0, 2, 42] = 10.0  # High logit for token 42 at pos 2
+     >>> predictions[1, 1, 17] = 10.0  # High logit for token 17 at pos 1
+     >>> labels = np.full((2, 5), -100)
+     >>> labels[0, 2] = 42  # Target at pos 2
+     >>> labels[1, 1] = 17  # Target at pos 1
+     >>> eval_pred = EvalPrediction(predictions=predictions, label_ids=labels)
+     >>> metrics = compute_cloze_metrics(eval_pred, tokenizer)
+     >>> metrics["accuracy"]
+     1.0
+     """
+     predictions = eval_pred.predictions
+     labels = eval_pred.label_ids
+
+     # Handle empty or invalid inputs
+     if predictions is None or predictions.size == 0:
+         return {
+             "accuracy": 0.0,
+             "top_3_accuracy": 0.0,
+             "top_5_accuracy": 0.0,
+             "perplexity": float("inf"),
+         }
+
+     if labels is None:
+         return {
+             "accuracy": 0.0,
+             "top_3_accuracy": 0.0,
+             "top_5_accuracy": 0.0,
+             "perplexity": float("inf"),
+         }
+
+     # Validate shapes
+     if predictions.ndim != 3:
+         # Unexpected shape, return defaults
+         return {
+             "accuracy": 0.0,
+             "top_3_accuracy": 0.0,
+             "top_5_accuracy": 0.0,
+             "perplexity": float("inf"),
+         }
+
+     if labels.ndim != 2:
+         return {
+             "accuracy": 0.0,
+             "top_3_accuracy": 0.0,
+             "top_5_accuracy": 0.0,
+             "perplexity": float("inf"),
+         }
+
+     # Check shape compatibility
+     if predictions.shape[:2] != labels.shape:
+         return {
+             "accuracy": 0.0,
+             "top_3_accuracy": 0.0,
+             "top_5_accuracy": 0.0,
+             "perplexity": float("inf"),
+         }
+
+     # Find masked positions (where label != -100)
+     mask = labels != -100
+
+     # Handle case with no masked positions
+     if not mask.any():
+         return {
+             "accuracy": 0.0,
+             "top_3_accuracy": 0.0,
+             "top_5_accuracy": 0.0,
+             "perplexity": float("inf"),
+         }
+
+     n_total = int(mask.sum())
+
+     # Compute top-1 accuracy
+     pred_tokens = np.argmax(predictions, axis=-1)  # (n_samples, seq_len)
+     correct = (pred_tokens == labels) & mask
+     n_correct = int(correct.sum())
+     accuracy = float(n_correct) / float(n_total)
+
+     # Compute top-k accuracy using argpartition (efficient for large vocab)
+     def compute_topk_accuracy(k: int) -> float:
+         """Compute top-k accuracy at masked positions."""
+         vocab_size = predictions.shape[2]
+         if k >= vocab_size:
+             # All tokens are in top-k
+             return 1.0
+
+         # Get top-k indices: shape (n_samples, seq_len, k)
+         topk_indices = np.argpartition(predictions, -k, axis=-1)[..., -k:]
+
+         # Expand labels for comparison: (n_samples, seq_len, 1)
+         labels_expanded = labels[..., np.newaxis]
+
+         # Check if label is in top-k for each position
+         in_topk = (topk_indices == labels_expanded).any(axis=-1)
+
+         # Apply mask and compute accuracy
+         correct_topk = in_topk & mask
+         n_correct_k = int(correct_topk.sum())
+         return float(n_correct_k) / float(n_total)
+
+     top_3_accuracy = compute_topk_accuracy(3)
+     top_5_accuracy = compute_topk_accuracy(5)
+
+     # Compute perplexity
+     # Perplexity = exp(average cross-entropy loss)
+     def compute_perplexity() -> float:
+         """Compute perplexity at masked positions."""
+         # Numerically stable softmax using log-sum-exp trick
+         max_logits = predictions.max(axis=-1, keepdims=True)
+         shifted = predictions - max_logits
+         exp_logits = np.exp(shifted)
+         sum_exp = exp_logits.sum(axis=-1, keepdims=True)
+         log_probs = shifted - np.log(sum_exp)  # log softmax
+
+         # Get log probabilities at label positions
+         n_samples, seq_len, _ = predictions.shape
+
+         # Create indices for gathering
+         batch_indices = np.arange(n_samples)[:, np.newaxis]
+         seq_indices = np.arange(seq_len)[np.newaxis, :]
+
+         # Handle -100 labels by replacing with 0 temporarily (they'll be masked out)
+         safe_labels = np.where(labels >= 0, labels, 0)
+
+         # Gather log probs: log_probs[i, j, labels[i, j]]
+         target_log_probs = log_probs[batch_indices, seq_indices, safe_labels]
+
+         # Cross-entropy is negative log prob
+         cross_entropy = -target_log_probs  # (n_samples, seq_len)
+
+         # Average over masked positions only
+         masked_ce = cross_entropy[mask]
+         if len(masked_ce) == 0:
+             return float("inf")
+
+         avg_ce = float(masked_ce.mean())
+
+         # Perplexity = exp(average cross-entropy)
+         # Clip to avoid overflow
+         if avg_ce > 100:
+             return float("inf")
+
+         return float(np.exp(avg_ce))
+
+     perplexity = compute_perplexity()
+
+     return {
+         "accuracy": accuracy,
+         "top_3_accuracy": top_3_accuracy,
+         "top_5_accuracy": top_5_accuracy,
+         "perplexity": perplexity,
+     }
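As a sanity check on the perplexity definition used in compute_cloze_metrics (the exponential of the mean cross-entropy at masked positions): with uniform logits over a vocabulary of size V, the cross-entropy at every masked position is log(V), so the perplexity equals V. A small illustrative check with assumed inputs, not package code:

import numpy as np

V = 100
logits = np.zeros((1, 4, V))                 # uniform logits over the vocabulary
labels = np.full((1, 4), -100)
labels[0, 2] = 7                             # one masked position with target token 7

# log-softmax of uniform logits is -log(V) everywhere
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
cross_entropy = -log_probs[0, 2, labels[0, 2]]   # = log(V)
assert np.isclose(np.exp(cross_entropy), V)      # perplexity equals the vocab size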