bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/models/cloze.py
@@ -0,0 +1,862 @@
+"""Cloze model for fill-in-the-blank tasks with GLMM support.
+
+Implements masked language modeling with participant-level random effects for
+predicting tokens at unfilled slots in partially-filled templates. Supports
+three modes: fixed effects, random intercepts, random slopes.
+
+Architecture: Masked LM (BERT/RoBERTa) for token prediction
+"""
+
+from __future__ import annotations
+
+import copy
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import torch
+from transformers import AutoModelForMaskedLM, AutoTokenizer, TrainingArguments
+
+from bead.active_learning.config import VarianceComponents
+from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
+from bead.active_learning.models.random_effects import RandomEffectsManager
+from bead.active_learning.trainers.data_collator import ClozeDataCollator
+from bead.active_learning.trainers.dataset_utils import cloze_items_to_dataset
+from bead.config.active_learning import ClozeModelConfig
+from bead.items.item import Item
+from bead.items.item_template import ItemTemplate, TaskType
+
+__all__ = ["ClozeModel"]
+
+
+class ClozeModel(ActiveLearningModel):
+    """Model for cloze tasks with participant-level random effects.
+
+    Uses masked language modeling (BERT/RoBERTa) to predict tokens at unfilled
+    slots in partially-filled templates. Supports three GLMM modes:
+    - Fixed effects: Standard MLM
+    - Random intercepts: Participant-specific bias on output logits
+    - Random slopes: Participant-specific MLM heads
+
+    Parameters
+    ----------
+    config : ClozeModelConfig
+        Configuration object containing all model parameters.
+
+    Attributes
+    ----------
+    config : ClozeModelConfig
+        Model configuration.
+    tokenizer : AutoTokenizer
+        Masked LM tokenizer.
+    model : AutoModelForMaskedLM
+        Masked language model (BERT or RoBERTa).
+    encoder : nn.Module
+        Encoder module from the model.
+    mlm_head : nn.Module
+        MLM prediction head.
+    random_effects : RandomEffectsManager
+        Manager for participant-level random effects.
+    variance_history : list[VarianceComponents]
+        Variance component estimates over training.
+    _is_fitted : bool
+        Whether model has been trained.
+
+    Examples
+    --------
+    >>> from uuid import uuid4
+    >>> from bead.items.item import Item, UnfilledSlot
+    >>> from bead.config.active_learning import ClozeModelConfig
+    >>> items = [
+    ...     Item(
+    ...         item_template_id=uuid4(),
+    ...         rendered_elements={"text": "The cat ___."},
+    ...         unfilled_slots=[
+    ...             UnfilledSlot(slot_name="verb", position=2, constraint_ids=[])
+    ...         ]
+    ...     )
+    ...     for _ in range(6)
+    ... ]
+    >>> labels = [["ran"], ["jumped"], ["slept"]] * 2  # One token per unfilled slot
+    >>> config = ClozeModelConfig(  # doctest: +SKIP
+    ...     num_epochs=1, batch_size=2, device="cpu"
+    ... )
+    >>> model = ClozeModel(config=config)  # doctest: +SKIP
+    >>> metrics = model.train(items, labels, participant_ids=None)  # doctest: +SKIP
+    """
+
+    def __init__(
+        self,
+        config: ClozeModelConfig | None = None,
+    ) -> None:
+        """Initialize cloze model.
+
+        Parameters
+        ----------
+        config : ClozeModelConfig | None
+            Configuration object. If None, uses default configuration.
+        """
+        self.config = config or ClozeModelConfig()
+
+        # Validate mixed_effects configuration
+        super().__init__(self.config)
+
+        # Load tokenizer and masked LM model
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+        self.model = AutoModelForMaskedLM.from_pretrained(self.config.model_name)
+
+        # Extract encoder and MLM head
+        # BERT-style models use model.bert and model.cls
+        # RoBERTa-style models use model.roberta and model.lm_head
+        if hasattr(self.model, "bert"):
+            self.encoder = self.model.bert
+            self.mlm_head = self.model.cls
+        elif hasattr(self.model, "roberta"):
+            self.encoder = self.model.roberta
+            self.mlm_head = self.model.lm_head
+        else:
+            # Fallback: try to use the base model attribute
+            self.encoder = self.model.base_model
+            self.mlm_head = self.model.lm_head
+
+        self._is_fitted = False
+
+        # Initialize random effects manager (created during training)
+        self.random_effects: RandomEffectsManager | None = None
+        self.variance_history: list[VarianceComponents] = []
+
+        self.model.to(self.config.device)
+
+    @property
+    def supported_task_types(self) -> list[TaskType]:
+        """Get supported task types.
+
+        Returns
+        -------
+        list[TaskType]
+            List containing "cloze".
+        """
+        return ["cloze"]
+
+    def validate_item_compatibility(
+        self, item: Item, item_template: ItemTemplate
+    ) -> None:
+        """Validate item is compatible with cloze model.
+
+        Parameters
+        ----------
+        item : Item
+            Item to validate.
+        item_template : ItemTemplate
+            Template the item was constructed from.
+
+        Raises
+        ------
+        ValueError
+            If task_type is not "cloze".
+        ValueError
+            If item has no unfilled_slots.
+        """
+        if item_template.task_type != "cloze":
+            raise ValueError(
+                f"Expected task_type 'cloze', got '{item_template.task_type}'"
+            )
+
+        if not item.unfilled_slots:
+            raise ValueError(
+                "Cloze items must have at least one unfilled slot. "
+                f"Item {item.id} has no unfilled_slots."
+            )
+
+    def _prepare_inputs_and_masks(
+        self, items: list[Item]
+    ) -> tuple[dict[str, torch.Tensor], list[list[int]]]:
+        """Prepare tokenized inputs with masked positions.
+
+        Extracts text from items, tokenizes, and replaces tokens at unfilled_slots
+        positions with [MASK] token.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Items to prepare.
+
+        Returns
+        -------
+        tuple[dict[str, torch.Tensor], list[list[int]]]
+            - Tokenized inputs (input_ids, attention_mask)
+            - List of masked token positions per item (token-level indices)
+        """
+        texts = []
+        n_slots_per_item = []
+
+        for item in items:
+            # Get rendered text
+            text = item.rendered_elements.get("text", "")
+            texts.append(text)
+            n_slots_per_item.append(len(item.unfilled_slots))
+
+        # Tokenize all texts
+        tokenized = self.tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            max_length=self.config.max_length,
+            return_tensors="pt",
+        ).to(self.config.device)
+
+        mask_token_id = self.tokenizer.mask_token_id
+
+        # Find and replace "___" placeholders with [MASK]
+        # Track ONE position per unfilled slot (even if "___" spans multiple tokens)
+        token_masked_positions = []
+        for i, text in enumerate(texts):
+            # Tokenize individually to find "___" positions
+            tokens = self.tokenizer.tokenize(text)
+            masked_indices = []
+
+            # Track which tokens are part of "___" to avoid duplicates
+            in_blank = False
+            for j, token in enumerate(tokens):
+                # Check if this token is part of a "___" placeholder
+                if "_" in token and not in_blank:
+                    # Start of a new blank - record this position
+                    token_idx = j + 1  # Add 1 for [CLS] token
+                    masked_indices.append(token_idx)
+                    in_blank = True
+                    # Replace with [MASK]
+                    if token_idx < tokenized["input_ids"].shape[1]:
+                        tokenized["input_ids"][i, token_idx] = mask_token_id
+                elif "_" in token and in_blank:
+                    # Continuation of current blank - also mask but don't record
+                    token_idx = j + 1
+                    if token_idx < tokenized["input_ids"].shape[1]:
+                        tokenized["input_ids"][i, token_idx] = mask_token_id
+                else:
+                    # Not a blank token - reset in_blank
+                    in_blank = False
+
+            # Verify we found the expected number of masked positions
+            expected_slots = n_slots_per_item[i]
+            if len(masked_indices) != expected_slots:
+                raise ValueError(
+                    f"Mismatch between masked positions and unfilled_slots "
+                    f"for item {i}: found {len(masked_indices)} '___' "
+                    f"placeholders in text but item has {expected_slots} "
+                    f"unfilled_slots. Ensure rendered text uses exactly one "
+                    f"'___' per unfilled_slot. Text: '{text}'"
+                )
+
+            token_masked_positions.append(masked_indices)
+
+        return tokenized, token_masked_positions
+
+    def _prepare_training_data(
+        self,
+        items: list[Item],
+        labels: list[str] | list[list[str]],
+        participant_ids: list[str],
+        validation_items: list[Item] | None,
+        validation_labels: list[str] | list[list[str]] | None,
+    ) -> tuple[
+        list[Item],
+        list[list[str]],
+        list[str],
+        list[Item] | None,
+        list[list[str]] | None,
+    ]:
+        """Prepare data for training, including validation of label format.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items.
+        labels : list[list[str]]
+            Training labels as list of lists (one token per unfilled slot).
+        participant_ids : list[str]
+            Participant IDs (already normalized).
+        validation_items : list[Item] | None
+            Validation items.
+        validation_labels : list[list[str]] | None
+            Validation labels.
+
+        Returns
+        -------
+        tuple[list[Item], list[list[str]], list[str], list[Item] | None, list[list[str]] | None]
+            Prepared items, labels, participant IDs, validation items,
+            validation labels.
+        """
+        # Validate labels format: each label must be a list matching unfilled_slots
+        labels_list = list(labels)  # Type: list[list[str]]
+        for i, (item, label) in enumerate(zip(items, labels_list, strict=True)):
+            if not isinstance(label, list):
+                raise ValueError(
+                    f"ClozeModel requires labels to be list[list[str]], "
+                    f"but got {type(label)} for item {i}"
+                )
+            if len(label) != len(item.unfilled_slots):
+                raise ValueError(
+                    f"Label length mismatch for item {i}: "
+                    f"expected {len(item.unfilled_slots)} tokens "
+                    f"(matching unfilled_slots), got {len(label)} tokens. "
+                    f"Ensure each label is a list with one token per unfilled slot."
+                )
+
+        val_labels_list: list[list[str]] | None = None
+        if validation_items is not None and validation_labels is not None:
+            val_labels_list = list(validation_labels)  # Type: list[list[str]]
+            for i, (item, label) in enumerate(
+                zip(validation_items, val_labels_list, strict=True)
+            ):
+                if not isinstance(label, list):
+                    raise ValueError(
+                        f"ClozeModel requires validation_labels to be list[list[str]], "
+                        f"but got {type(label)} for validation item {i}"
+                    )
+                if len(label) != len(item.unfilled_slots):
+                    raise ValueError(
+                        f"Validation label length mismatch for item {i}: "
+                        f"expected {len(item.unfilled_slots)} tokens, "
+                        f"got {len(label)} tokens."
+                    )
+
+        return items, labels_list, participant_ids, validation_items, val_labels_list
+
+    def _do_training(
+        self,
+        items: list[Item],
+        labels_numeric: list[list[str]],
+        participant_ids: list[str],
+        validation_items: list[Item] | None,
+        validation_labels_numeric: list[list[str]] | None,
+    ) -> dict[str, float]:
+        """Perform the actual training logic (HuggingFace Trainer or custom loop).
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items.
+        labels_numeric : list[list[str]]
+            Training labels (already validated).
+        participant_ids : list[str]
+            Participant IDs.
+        validation_items : list[Item] | None
+            Validation items.
+        validation_labels_numeric : list[list[str]] | None
+            Validation labels.
+
+        Returns
+        -------
+        dict[str, float]
+            Training metrics.
+        """
+        # Use HuggingFace Trainer for fixed and random_intercepts modes
+        # random_slopes requires custom loop due to per-participant MLM heads
+        use_huggingface_trainer = self.config.mixed_effects.mode in (
+            "fixed",
+            "random_intercepts",
+        )
+
+        if use_huggingface_trainer:
+            return self._train_with_huggingface_trainer(
+                items,
+                labels_numeric,
+                participant_ids,
+                validation_items,
+                validation_labels_numeric,
+            )
+        else:
+            # Use custom training loop for random_slopes
+            return self._train_with_custom_loop(
+                items,
+                labels_numeric,
+                participant_ids,
+                validation_items,
+                validation_labels_numeric,
+            )
+
+    def _train_with_huggingface_trainer(
+        self,
+        items: list[Item],
+        labels: list[list[str]],
+        participant_ids: list[str],
+        validation_items: list[Item] | None,
+        validation_labels: list[list[str]] | None,
+    ) -> dict[str, float]:
+        """Train using HuggingFace Trainer with mixed effects support for MLM.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items with unfilled_slots.
+        labels : list[list[str]]
+            Training labels as list of lists (one token per unfilled slot).
+        participant_ids : list[str]
+            Participant IDs.
+        validation_items : list[Item] | None
+            Validation items.
+        validation_labels : list[list[str]] | None
+            Validation labels.
+
+        Returns
+        -------
+        dict[str, float]
+            Training metrics.
+        """
+        # Convert items to HuggingFace Dataset with masking
+        train_dataset = cloze_items_to_dataset(
+            items=items,
+            labels=labels,
+            participant_ids=participant_ids,
+            tokenizer=self.tokenizer,
+            max_length=self.config.max_length,
+        )
+
+        eval_dataset = None
+        if validation_items is not None and validation_labels is not None:
+            val_participant_ids = (
+                ["_validation_"] * len(validation_items)
+                if self.config.mixed_effects.mode != "fixed"
+                else ["_fixed_"] * len(validation_items)
+            )
+            eval_dataset = cloze_items_to_dataset(
+                items=validation_items,
+                labels=validation_labels,
+                participant_ids=val_participant_ids,
+                tokenizer=self.tokenizer,
+                max_length=self.config.max_length,
+            )
+
+        # Use the model directly (no wrapper needed for MLM models)
+        # The model is already compatible with HuggingFace Trainer
+        wrapped_model = self.model
+
+        # Create data collator
+        data_collator = ClozeDataCollator(tokenizer=self.tokenizer)
+
+        # Create training arguments with checkpointing
+        with tempfile.TemporaryDirectory() as tmpdir:
+            checkpoint_dir = Path(tmpdir) / "checkpoints"
+            checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+            training_args = TrainingArguments(
+                output_dir=str(checkpoint_dir),
+                num_train_epochs=self.config.num_epochs,
+                per_device_train_batch_size=self.config.batch_size,
+                per_device_eval_batch_size=self.config.batch_size,
+                learning_rate=self.config.learning_rate,
+                logging_steps=10,
+                eval_strategy="epoch" if eval_dataset is not None else "no",
+                save_strategy="epoch",
+                save_total_limit=1,
+                load_best_model_at_end=False,
+                report_to="none",
+                remove_unused_columns=False,
+                use_cpu=self.config.device == "cpu",
+            )
+
+            # Create metrics computation function
+            def compute_metrics_fn(eval_pred: object) -> dict[str, float]:
+                from bead.active_learning.trainers.metrics import (  # noqa: PLC0415
+                    compute_cloze_metrics,
+                )
+
+                return compute_cloze_metrics(eval_pred, tokenizer=self.tokenizer)
+
+            # Import here to avoid circular import
+            from bead.active_learning.trainers.mixed_effects import (  # noqa: PLC0415
+                ClozeMLMTrainer,
+            )
+
+            # Create trainer
+            trainer = ClozeMLMTrainer(
+                model=wrapped_model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+                data_collator=data_collator,
+                tokenizer=self.tokenizer,
+                random_effects_manager=self.random_effects,
+                compute_metrics=compute_metrics_fn,
+            )
+
+            # Train
+            train_result = trainer.train()
+
+            # Get training metrics
+            train_metrics = trainer.evaluate(eval_dataset=train_dataset)
+            metrics: dict[str, float] = {
+                "train_loss": float(train_result.training_loss),
+                "train_accuracy": train_metrics.get("eval_accuracy", 0.0),
+            }
+
+            # Get validation metrics if eval_dataset was provided
+            if eval_dataset is not None:
+                val_metrics = trainer.evaluate(eval_dataset=eval_dataset)
+                metrics["val_accuracy"] = val_metrics.get("eval_accuracy", 0.0)
+
+        return metrics
+
+    def _train_with_custom_loop(
+        self,
+        items: list[Item],
+        labels: list[list[str]],
+        participant_ids: list[str],
+        validation_items: list[Item] | None,
+        validation_labels: list[list[str]] | None,
+    ) -> dict[str, float]:
+        """Train using custom loop for random_slopes mode.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items with unfilled_slots.
+        labels : list[list[str]]
+            Training labels as list of lists.
+        participant_ids : list[str]
+            Participant IDs.
+        validation_items : list[Item] | None
+            Validation items.
+        validation_labels : list[list[str]] | None
+            Validation labels.
+
+        Returns
+        -------
+        dict[str, float]
+            Training metrics.
+        """
+        # Build optimizer parameters
+        params_to_optimize = list(self.model.parameters())
+
+        # Add random effects parameters for random_slopes
+        for head in self.random_effects.slopes.values():
+            params_to_optimize.extend(head.parameters())
+
+        optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
+
+        self.model.train()
+
+        for _epoch in range(self.config.num_epochs):
+            n_batches = (
+                len(items) + self.config.batch_size - 1
+            ) // self.config.batch_size
+            epoch_loss = 0.0
+
+            for i in range(n_batches):
+                start_idx = i * self.config.batch_size
+                end_idx = min(start_idx + self.config.batch_size, len(items))
+
+                batch_items = items[start_idx:end_idx]
+                batch_labels = labels[start_idx:end_idx]
+                batch_participant_ids = participant_ids[start_idx:end_idx]
+
+                # Prepare inputs with masking
+                tokenized, masked_positions = self._prepare_inputs_and_masks(
+                    batch_items
+                )
+
+                # Tokenize labels to get target token IDs
+                target_token_ids = []
+                for label_list in batch_labels:
+                    token_ids = []
+                    for token in label_list:
+                        tid = self.tokenizer.encode(token, add_special_tokens=False)[0]
+                        token_ids.append(tid)
+                    target_token_ids.append(token_ids)
+
+                # Use participant-specific MLM heads for random_slopes
+                all_logits = []
+                for j, pid in enumerate(batch_participant_ids):
+                    # Get participant-specific MLM head
+                    participant_head = self.random_effects.get_slopes(
+                        pid,
+                        fixed_head=copy.deepcopy(self.mlm_head),
+                        create_if_missing=True,
+                    )
+
+                    # Get encoder outputs for this item
+                    item_inputs = {k: v[j : j + 1] for k, v in tokenized.items()}
+                    encoder_outputs_j = self.encoder(**item_inputs)
+
+                    # Run participant-specific MLM head
+                    logits_j = participant_head(encoder_outputs_j.last_hidden_state)
+                    all_logits.append(logits_j)
+
+                logits = torch.cat(all_logits, dim=0)
+
+                # Compute loss only on masked positions
+                losses = []
+                for j, (masked_pos, target_ids) in enumerate(
+                    zip(masked_positions, target_token_ids, strict=True)
+                ):
+                    for pos, target_id in zip(masked_pos, target_ids, strict=True):
+                        if pos < logits.shape[1]:
+                            loss_j = torch.nn.functional.cross_entropy(
+                                logits[j, pos : pos + 1],
+                                torch.tensor([target_id], device=self.config.device),
+                            )
+                            losses.append(loss_j)
+
+                if losses:
+                    loss_nll = torch.stack(losses).mean()
+                else:
+                    loss_nll = torch.tensor(0.0, device=self.config.device)
+
+                # Add prior regularization
+                loss_prior = self.random_effects.compute_prior_loss()
+                loss = loss_nll + loss_prior
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                epoch_loss += loss.item()
+
+            epoch_loss = epoch_loss / n_batches
+
+        metrics: dict[str, float] = {
+            "train_loss": epoch_loss,
+        }
+
+        # Compute training accuracy
+        train_predictions = self._do_predict(items, participant_ids)
+        correct = 0
+        total = 0
+        for pred, label in zip(train_predictions, labels, strict=True):
+            # pred.predicted_class is comma-separated tokens
+            pred_tokens = pred.predicted_class.split(", ")
+            for pt, lt in zip(pred_tokens, label, strict=True):
+                if pt.lower() == lt.lower():
+                    correct += 1
+                total += 1
+        if total > 0:
+            metrics["train_accuracy"] = correct / total
+
+        return metrics
+
+    def _do_predict(
+        self, items: list[Item], participant_ids: list[str]
+    ) -> list[ModelPrediction]:
+        """Perform cloze model prediction.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Items to predict.
+        participant_ids : list[str]
+            Normalized participant IDs.
+
+        Returns
+        -------
+        list[ModelPrediction]
+            Predictions with predicted_class as comma-separated tokens.
+        """
+        self.model.eval()
+
+        # Prepare inputs with masking
+        tokenized, masked_positions = self._prepare_inputs_and_masks(items)
+
+        with torch.no_grad():
+            if self.config.mixed_effects.mode == "fixed":
+                # Standard MLM prediction
+                outputs = self.model(**tokenized)
+                logits = outputs.logits
+
+            elif self.config.mixed_effects.mode == "random_intercepts":
+                # Get encoder outputs
+                encoder_outputs = self.encoder(**tokenized)
+                logits = self.mlm_head(encoder_outputs.last_hidden_state)
+
+                # Add participant-specific bias
+                vocab_size = self.tokenizer.vocab_size
+                for j, pid in enumerate(participant_ids):
+                    bias = self.random_effects.get_intercepts(
+                        pid,
+                        n_classes=vocab_size,
+                        param_name="mu",
+                        create_if_missing=False,
+                    )
+                    # Add to all masked positions
+                    for pos in masked_positions[j]:
+                        if pos < logits.shape[1]:
+                            logits[j, pos] = logits[j, pos] + bias
+
+            elif self.config.mixed_effects.mode == "random_slopes":
+                # Use participant-specific MLM heads
+                all_logits = []
+                for j, pid in enumerate(participant_ids):
+                    # Get participant-specific MLM head
+                    participant_head = self.random_effects.get_slopes(
+                        pid,
+                        fixed_head=copy.deepcopy(self.mlm_head),
+                        create_if_missing=False,
+                    )
+
+                    # Get encoder outputs
+                    item_inputs = {k: v[j : j + 1] for k, v in tokenized.items()}
+                    encoder_outputs_j = self.encoder(**item_inputs)
+
+                    # Run participant-specific MLM head
+                    logits_j = participant_head(encoder_outputs_j.last_hidden_state)
+                    all_logits.append(logits_j)
+
+                logits = torch.cat(all_logits, dim=0)
+
+        # Get argmax at masked positions
+        predictions = []
+        for i, masked_pos in enumerate(masked_positions):
+            predicted_tokens = []
+            for pos in masked_pos:
+                if pos < logits.shape[1]:
+                    # Get token ID with highest probability
+                    token_id = torch.argmax(logits[i, pos]).item()
+                    # Decode token
+                    token = self.tokenizer.decode([token_id])
+                    predicted_tokens.append(token.strip())
+
+            # Join with comma for multi-slot items
+            predicted_class = ", ".join(predicted_tokens)
+
+            predictions.append(
+                ModelPrediction(
+                    item_id=str(items[i].id),
+                    probabilities={},  # Not applicable for generation
+                    predicted_class=predicted_class,
+                    confidence=1.0,  # Not applicable for generation
+                )
+            )
+
+        return predictions
+
+    def _do_predict_proba(
+        self, items: list[Item], participant_ids: list[str]
+    ) -> np.ndarray:
+        """Perform cloze model probability prediction.
+
+        For cloze tasks, returns empty array as probabilities are not typically
+        used for evaluation.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Items to predict.
+        participant_ids : list[str]
+            Normalized participant IDs.
+
+        Returns
+        -------
+        np.ndarray
+            Empty array of shape (n_items, 0).
+        """
+        return np.zeros((len(items), 0))
+
+    def _save_model_components(self, save_path: Path) -> None:
+        """Save model-specific components (model, tokenizer).
+
+        Parameters
+        ----------
+        save_path : Path
+            Directory to save to.
+        """
+        self.model.save_pretrained(save_path / "model")
+        self.tokenizer.save_pretrained(save_path / "model")
+
+    def _get_save_state(self) -> dict[str, object]:
+        """Get model-specific state to save in config.json.
+
+        Returns
+        -------
+        dict[str, object]
+            State dictionary to include in config.json.
+        """
+        return {}
+
+    def _load_model_components(self, load_path: Path) -> None:
+        """Load model-specific components.
+
+        Parameters
+        ----------
+        load_path : Path
+            Directory to load from.
+        """
+        # Load config.json to reconstruct config
+        with open(load_path / "config.json") as f:
+            import json  # noqa: PLC0415
+
+            config_dict = json.load(f)
+
+        # Reconstruct MixedEffectsConfig if needed
+        if "mixed_effects" in config_dict and isinstance(
+            config_dict["mixed_effects"], dict
+        ):
+            from bead.active_learning.config import MixedEffectsConfig  # noqa: PLC0415
+
+            config_dict["mixed_effects"] = MixedEffectsConfig(
+                **config_dict["mixed_effects"]
+            )
+
+        # Reconstruct ClozeModelConfig
+        self.config = ClozeModelConfig(**config_dict)
+
+        # Load model
+        self.model = AutoModelForMaskedLM.from_pretrained(load_path / "model")
+        self.tokenizer = AutoTokenizer.from_pretrained(load_path / "model")
+
+        # Re-extract components
+        if hasattr(self.model, "bert"):
+            self.encoder = self.model.bert
+            self.mlm_head = self.model.cls
+        elif hasattr(self.model, "roberta"):
+            self.encoder = self.model.roberta
+            self.mlm_head = self.model.lm_head
+        else:
+            self.encoder = self.model.base_model
+            self.mlm_head = self.model.lm_head
+
+        self.model.to(self.config.device)
+
+    def _restore_training_state(self, config_dict: dict[str, object]) -> None:
+        """Restore model-specific training state from config_dict.
+
+        Parameters
+        ----------
+        config_dict : dict[str, object]
+            Configuration dictionary with training state.
+        """
+        # ClozeModel doesn't have additional training state to restore
+        pass
+
+    def _get_n_classes_for_random_effects(self) -> int:
+        """Get the number of classes for initializing RandomEffectsManager.
+
+        For cloze models, this is the vocabulary size.
+
+        Returns
+        -------
+        int
+            Vocabulary size.
+        """
+        return self.tokenizer.vocab_size
+
+    def _initialize_random_effects(self, n_classes: int) -> None:
+        """Initialize the RandomEffectsManager.
+
+        Parameters
+        ----------
+        n_classes : int
+            Vocabulary size for cloze models.
+        """
+        self.random_effects = RandomEffectsManager(
+            self.config.mixed_effects,
+            vocab_size=n_classes,  # For random intercepts (bias on logits)
+        )
+
+    def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
+        """Get the fixed head for random effects (classifier_head, etc.).
+
+        For cloze models, this is the MLM head.
+
+        Returns
+        -------
+        torch.nn.Module | None
+            The MLM head, or None if not applicable.
+        """
+        return self.mlm_head