bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,551 @@
+"""Mixed effects trainer for HuggingFace models.
+
+This module provides a custom Trainer that handles participant-level
+random effects (intercepts and slopes) while using HuggingFace Trainer
+infrastructure for optimization, checkpointing, and device management.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn.functional
+from transformers import Trainer, TrainingArguments
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
+
+from bead.active_learning.models.random_effects import RandomEffectsManager
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    import torch
+    from torch.utils.data import Dataset
+    from transformers import PreTrainedTokenizerBase
+
+
+class MixedEffectsTrainer(Trainer):
+    """HuggingFace Trainer with mixed effects support.
+
+    Extends HuggingFace Trainer to handle participant-level random effects
+    (random intercepts and random slopes) while using Trainer's
+    optimization, checkpointing, and device management.
+
+    The key innovation is overriding compute_loss to apply participant-specific
+    adjustments to model outputs before computing the loss. This preserves
+    the mixed effects functionality while using HuggingFace infrastructure.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The model to train (must support mixed effects).
+    args : TrainingArguments
+        HuggingFace training arguments.
+    train_dataset : Dataset
+        Training dataset (must include 'participant_id' field).
+    eval_dataset : Dataset | None
+        Evaluation dataset (optional).
+    random_effects_manager : RandomEffectsManager
+        Manager for participant-level random effects.
+    data_collator : Callable | None
+        Data collator (optional, uses default if None).
+    tokenizer : PreTrainedTokenizerBase | None
+        Tokenizer (optional, for data collation).
+    compute_metrics : Callable[[object], dict[str, float]] | None
+        Metrics computation function (optional).
+
+    Attributes
+    ----------
+    random_effects_manager : RandomEffectsManager
+        Manager for random effects.
+    mixed_effects_config : MixedEffectsConfig
+        Mixed effects configuration.
+
+    Examples
+    --------
+    >>> from transformers import AutoModelForSequenceClassification, TrainingArguments
+    >>> from datasets import Dataset
+    >>> config = MixedEffectsConfig(mode='random_intercepts')
+    >>> manager = RandomEffectsManager(config, n_classes=2)
+    >>> model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
+    >>> trainer = MixedEffectsTrainer(
+    ...     model=model,
+    ...     args=TrainingArguments(output_dir='./output'),
+    ...     train_dataset=dataset,
+    ...     random_effects_manager=manager
+    ... )
+    >>> trainer.train()
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        args: TrainingArguments,
+        train_dataset: Dataset,
+        random_effects_manager: RandomEffectsManager,
+        eval_dataset: Dataset | None = None,
+        data_collator: (
+            Callable[
+                [list[dict[str, torch.Tensor | str | int | float]]],
+                dict[str, torch.Tensor | list[str]],
+            ]
+            | None
+        ) = None,
+        tokenizer: PreTrainedTokenizerBase | None = None,
+        compute_metrics: Callable[[object], dict[str, float]] | None = None,
+    ) -> None:
+        """Initialize mixed effects trainer.
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to train.
+        args : TrainingArguments
+            Training arguments.
+        train_dataset : Dataset
+            Training dataset.
+        random_effects_manager : RandomEffectsManager
+            Random effects manager.
+        eval_dataset : Dataset | None
+            Evaluation dataset.
+        data_collator : Callable | None
+            Data collator.
+        tokenizer : PreTrainedTokenizerBase | None
+            Tokenizer.
+        compute_metrics : Callable[[object], dict[str, float]] | None
+            Metrics computation function.
+        """
+        super().__init__(
+            model=model,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            data_collator=data_collator,
+            processing_class=tokenizer,
+            compute_metrics=compute_metrics,
+        )
+        self.random_effects_manager = random_effects_manager
+        self.mixed_effects_config = random_effects_manager.config
+
+    def compute_loss(
+        self,
+        model: torch.nn.Module,
+        inputs: Mapping[str, torch.Tensor],
+        return_outputs: bool = False,
+        num_items_in_batch: int | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """Compute loss with mixed effects adjustments.
+
+        Overrides HuggingFace Trainer's compute_loss to:
+        1. Get standard model outputs
+        2. Apply participant-specific adjustments (intercepts/slopes)
+        3. Compute loss with prior regularization
+        4. Return loss (and optionally outputs)
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to compute loss for.
+        inputs : Mapping[str, torch.Tensor]
+            Input batch (must include 'participant_id' if mixed effects).
+            participant_id should be a list[str] in the dataset, but will be
+            converted to tensor by data collator.
+        return_outputs : bool
+            Whether to return model outputs.
+
+        Returns
+        -------
+        torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+            Loss tensor, or (loss, outputs) if return_outputs=True.
+        """
+        # Get labels and participant IDs
+        labels = inputs.get("labels")
+        participant_ids = inputs.get("participant_id")
+
+        # For random_slopes mode, pass participant_id to model (wrapper handles routing)
+        # For other modes, remove participant_id from inputs
+        if self.mixed_effects_config.mode == "random_slopes":
+            # RandomSlopesModelWrapper expects participant_id in forward()
+            model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
+        else:
+            excluded = ("labels", "participant_id")
+            model_inputs = {k: v for k, v in inputs.items() if k not in excluded}
+
+        # Standard forward pass
+        outputs = model(**model_inputs)
+        logits = outputs.logits if hasattr(outputs, "logits") else outputs
+
+        # Apply mixed effects adjustments
+        if self.mixed_effects_config.mode == "random_intercepts":
+            # Apply participant-specific biases to logits
+            if participant_ids is not None:
+                batch_size = logits.shape[0]
+                # Handle participant_ids: could be tensor of indices or list of strings
+                # In our case, we store participant_ids as strings in dataset
+                # The data collator will need to handle this specially
+                for i in range(batch_size):
+                    # Extract participant ID - data collator provides as list[str]
+                    if isinstance(participant_ids, list):
+                        pid_str = str(participant_ids[i])
+                    elif isinstance(participant_ids, torch.Tensor):
+                        # Fallback: if somehow tensor, convert
+                        pid_elem = participant_ids[i]
+                        pid_raw = pid_elem.item() if pid_elem.numel() == 1 else pid_elem
+                        pid_str = str(pid_raw)
+                    else:
+                        pid_str = str(participant_ids[i])
+
+                    # Get bias for this participant
+                    # For binary: n_classes=1 (scalar bias)
+                    n_classes = logits.shape[1] if logits.dim() > 1 else 1
+                    bias = self.random_effects_manager.get_intercepts(
+                        pid_str,
+                        n_classes=n_classes,
+                        param_name="mu",
+                        create_if_missing=True,
+                    )
+                    # Ensure bias is on same device as logits
+                    bias = bias.to(logits.device)
+                    # For binary, bias is scalar, add to logits
+                    if logits.dim() == 1:
+                        bias_val = bias[0] if bias.numel() > 0 else 0
+                        logits[i] = logits[i] + bias_val
+                    else:
+                        logits[i] = logits[i] + bias
+
+        elif self.mixed_effects_config.mode == "random_slopes":
+            # Random slopes are handled by RandomSlopesModelWrapper in forward()
+            # The model routes each sample through participant-specific heads
+            # Logits already incorporate random slopes - nothing to do here
+            pass
+
+        # Compute data loss
+        if labels is not None:
+            # Check if this is regression (continuous labels) or classification
+            # Regression: labels are float, logits shape is (batch, 1)
+            # Classification: labels are int/long, logits shape varies by task
+            if labels.dtype.is_floating_point:
+                # Regression task: use MSE loss
+                if logits.dim() == 2 and logits.shape[1] == 1:
+                    # Squeeze to (batch,)
+                    preds = logits.squeeze(1)
+                elif logits.dim() == 1:
+                    preds = logits
+                else:
+                    # Unexpected shape, use first column
+                    preds = logits[:, 0]
+                loss = torch.nn.functional.mse_loss(preds, labels.float())
+            elif logits.dim() == 1 or (logits.dim() == 2 and logits.shape[1] == 1):
+                # Binary classification
+                loss_fct = torch.nn.functional.binary_cross_entropy_with_logits
+                if labels.dim() == 0:
+                    labels = labels.unsqueeze(0)
+                if logits.dim() == 1:
+                    logits = logits.unsqueeze(1)
+                loss = loss_fct(logits.squeeze(1), labels.float())
+            else:
+                # Multi-class classification
+                loss_fct = torch.nn.functional.cross_entropy
+                loss = loss_fct(logits, labels.long())
+        else:
+            # No labels provided (unsupervised)
+            loss = torch.tensor(0.0, device=logits.device)
+
+        # Add prior regularization loss
+        loss_prior = self.random_effects_manager.compute_prior_loss()
+        if loss_prior.device != loss.device:
+            loss_prior = loss_prior.to(loss.device)
+        loss = loss + loss_prior
+
+        if return_outputs:
+            # Create output object with adjusted logits
+            adjusted_outputs = SequenceClassifierOutput(logits=logits)
+            return (loss, adjusted_outputs)
+        return loss
+
+    def create_optimizer(self) -> None:
+        """Create optimizer with all parameters including participant heads.
+
+        For random_slopes mode, this method collects parameters from:
+        1. The encoder (via model.encoder or model.model.encoder)
+        2. The fixed classifier head
+        3. All participant-specific heads (slopes)
+
+        For other modes, delegates to parent implementation.
+        """
+        if self.optimizer is not None:
+            # Optimizer already exists
+            return
+
+        if self.mixed_effects_config.mode == "random_slopes":
+            # Collect parameters for random_slopes mode
+            optimizer_grouped_parameters: list[dict[str, object]] = []
+
+            # Check if model has get_all_parameters method (RandomSlopesModelWrapper)
+            if hasattr(self.model, "get_all_parameters"):
+                all_params = self.model.get_all_parameters()
+                optimizer_grouped_parameters.append(
+                    {
+                        "params": all_params,
+                        "lr": self.args.learning_rate,
+                    }
+                )
+            else:
+                # Fallback: collect standard model parameters plus slope parameters
+                optimizer_grouped_parameters.append(
+                    {
+                        "params": list(self.model.parameters()),
+                        "lr": self.args.learning_rate,
+                    }
+                )
+
+                # Add participant head parameters from random_effects_manager
+                if hasattr(self.random_effects_manager, "slopes"):
+                    for head in self.random_effects_manager.slopes.values():
+                        if hasattr(head, "parameters"):
+                            optimizer_grouped_parameters.append(
+                                {
+                                    "params": list(head.parameters()),
+                                    "lr": self.args.learning_rate,
+                                }
+                            )
+
+            # Create AdamW optimizer
+            self.optimizer = torch.optim.AdamW(
+                optimizer_grouped_parameters,
+                lr=self.args.learning_rate,
+                weight_decay=self.args.weight_decay,
+            )
+        else:
+            # Use parent implementation for other modes
+            super().create_optimizer()
+
+
+class ClozeMLMTrainer(MixedEffectsTrainer):
+    """Custom trainer for cloze (MLM) tasks with custom masking positions.
+
+    Extends MixedEffectsTrainer to handle MLM loss computation only on
+    specific masked positions (from unfilled_slots) rather than all positions.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        MLM model (AutoModelForMaskedLM or wrapper).
+    args : TrainingArguments
+        Training arguments.
+    train_dataset : Dataset
+        Training dataset (must include 'masked_positions' and 'target_token_ids').
+    random_effects_manager : RandomEffectsManager
+        Random effects manager.
+    eval_dataset : Dataset | None
+        Evaluation dataset.
+    data_collator : Callable | None
+        Data collator (should be ClozeDataCollator).
+    tokenizer : PreTrainedTokenizerBase | None
+        Tokenizer.
+    compute_metrics : Callable[[object], dict[str, float]] | None
+        Metrics computation function.
+
+    Examples
+    --------
+    >>> from transformers import AutoModelForMaskedLM, TrainingArguments
+    >>> model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
+    >>> trainer = ClozeMLMTrainer(
+    ...     model=model,
+    ...     args=TrainingArguments(output_dir='./output'),
+    ...     train_dataset=dataset,
+    ...     random_effects_manager=manager
+    ... )
+    >>> trainer.train()
+    """
+
+    def compute_loss(
+        self,
+        model: torch.nn.Module,
+        inputs: Mapping[str, torch.Tensor | list[str] | list[list[int]]],
+        return_outputs: bool = False,
+        num_items_in_batch: int | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """Compute MLM loss only on masked positions.
+
+        Overrides MixedEffectsTrainer's compute_loss to:
+        1. Get model outputs (logits for all positions)
+        2. Apply participant-specific adjustments (intercepts) if needed
+        3. Compute cross-entropy loss only on masked positions
+        4. Add prior regularization
+        5. Return loss (and optionally outputs)
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to compute loss for.
+        inputs : Mapping[str, torch.Tensor | list[str] | list[list[int]]]
+            Input batch with:
+            - Standard tokenized inputs (input_ids, attention_mask, etc.)
+            - participant_id: list[str]
+            - masked_positions: list[list[int]] - masked token positions per item
+            - target_token_ids: list[list[int]] - target token IDs per masked position
+        return_outputs : bool
+            Whether to return model outputs.
+        num_items_in_batch : int | None
+            Unused, kept for compatibility.
+
+        Returns
+        -------
+        torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+            Loss tensor, or (loss, outputs) if return_outputs=True.
+        """
+        # Extract cloze-specific fields
+        participant_ids = inputs.get("participant_id")
+        masked_positions = inputs.get("masked_positions", [])
+        target_token_ids = inputs.get("target_token_ids", [])
+
+        # Remove these from inputs for model forward pass
+        excluded = ("labels", "participant_id", "masked_positions", "target_token_ids")
+        model_inputs = {k: v for k, v in inputs.items() if k not in excluded}
+
+        # Standard forward pass
+        outputs = model(**model_inputs)
+        logits = outputs.logits if hasattr(outputs, "logits") else outputs
+        # logits shape: (batch, seq_len, vocab_size)
+
+        # Apply mixed effects adjustments for random_intercepts
+        if self.mixed_effects_config.mode == "random_intercepts":
+            if participant_ids is not None and isinstance(participant_ids, list):
+                vocab_size = logits.shape[2]
+                batch_size = logits.shape[0]
+                for i in range(batch_size):
+                    pid_str = str(participant_ids[i])
+                    # Get bias for this participant (vocab_size,)
+                    bias = self.random_effects_manager.get_intercepts(
+                        pid_str,
+                        n_classes=vocab_size,
+                        param_name="mu",
+                        create_if_missing=True,
+                    )
+                    bias = bias.to(logits.device)
+                    # Add bias to all masked positions for this item
+                    in_range = i < len(masked_positions)
+                    if in_range and isinstance(masked_positions[i], list):
+                        for pos in masked_positions[i]:
+                            if pos < logits.shape[1]:
+                                logits[i, pos] = logits[i, pos] + bias
+
+        # Compute loss only on masked positions
+        losses: list[torch.Tensor] = []
+        if isinstance(masked_positions, list) and isinstance(target_token_ids, list):
+            for j, (masked_pos, target_ids) in enumerate(
+                zip(masked_positions, target_token_ids, strict=True)
+            ):
+                if j >= logits.shape[0]:
+                    continue
+                if isinstance(masked_pos, list) and isinstance(target_ids, list):
+                    for pos, target_id in zip(masked_pos, target_ids, strict=True):
+                        if pos < logits.shape[1]:
+                            # Cross-entropy loss for this position
+                            # logits[j, pos] shape: (vocab_size,)
+                            # target_id: int
+                            # Need shape (1, vocab_size) for logits and (1,) for target
+                            pos_logits = logits[j, pos].unsqueeze(0)  # (1, vocab_size)
+                            pos_target = torch.tensor(
+                                [target_id], device=logits.device, dtype=torch.long
+                            )  # (1,)
+                            loss_j = torch.nn.functional.cross_entropy(
+                                pos_logits, pos_target
+                            )
+                            losses.append(loss_j)
+
+        if losses:
+            loss_nll = torch.stack(losses).mean()
+        else:
+            loss_nll = torch.tensor(0.0, device=logits.device)
+
+        # Add prior regularization loss
+        loss_prior = self.random_effects_manager.compute_prior_loss()
+        if loss_prior.device != loss_nll.device:
+            loss_prior = loss_prior.to(loss_nll.device)
+        loss = loss_nll + loss_prior
+
+        if return_outputs:
+            # Return outputs with logits
+            adjusted_outputs = MaskedLMOutput(logits=logits)
+            return (loss, adjusted_outputs)
+        return loss
+
+    def prediction_step(
+        self,
+        model: torch.nn.Module,
+        inputs: dict[str, torch.Tensor | list[str] | list[list[int]]],
+        prediction_loss_only: bool,
+        ignore_keys: list[str] | None = None,
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+        """Perform a prediction step with cloze-specific label encoding.
+
+        Creates labels tensor encoding target_token_ids at masked_positions
+        with -100 elsewhere (HuggingFace ignore index convention). This enables
+        compute_cloze_metrics() to evaluate predictions at the correct positions.
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to use for prediction.
+        inputs : dict[str, torch.Tensor | list[str] | list[list[int]]]
+            Input batch with:
+            - Standard tokenized inputs (input_ids, attention_mask, etc.)
+            - participant_id: list[str]
+            - masked_positions: list[list[int]] - masked token positions per item
+            - target_token_ids: list[list[int]] - target token IDs per position
+        prediction_loss_only : bool
+            Whether to only return loss.
+        ignore_keys : list[str] | None
+            Keys to ignore (unused).
+
+        Returns
+        -------
+        tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]
+            (loss, logits, labels) tuple where labels encodes target tokens
+            at masked positions with -100 elsewhere.
+        """
+        # Extract cloze-specific fields
+        masked_positions = inputs.get("masked_positions", [])
+        target_token_ids = inputs.get("target_token_ids", [])
+
+        # Filter inputs for model forward pass
+        model_inputs = {
+            k: v
+            for k, v in inputs.items()
+            if k not in ("participant_id", "masked_positions", "target_token_ids")
+        }
+
+        # Get predictions from parent (which handles compute_loss internally)
+        loss, logits, _ = super().prediction_step(
+            model, model_inputs, prediction_loss_only, ignore_keys
+        )
+
+        if prediction_loss_only:
+            return (loss, None, None)
+
+        # Build labels tensor: (batch_size, seq_len) with -100 default
+        labels = None
+        has_masks = isinstance(masked_positions, list)
+        has_targets = isinstance(target_token_ids, list)
+        if logits is not None and has_masks and has_targets:
+            batch_size, seq_len = logits.shape[:2]
+            labels = torch.full(
+                (batch_size, seq_len), -100, dtype=torch.long, device=logits.device
+            )
+
+            # Fill in target token IDs at masked positions
+            for i, (positions, targets) in enumerate(
+                zip(masked_positions, target_token_ids, strict=False)
+            ):
+                if i >= batch_size:
+                    break
+                if isinstance(positions, list) and isinstance(targets, list):
+                    for pos, target_id in zip(positions, targets, strict=False):
+                        if isinstance(pos, int) and isinstance(target_id, int):
+                            if 0 <= pos < seq_len:
+                                labels[i, pos] = target_id
+
+        return (loss, logits, labels)
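
Note on the collation contract the trainers above rely on: compute_loss expects participant_id to arrive as a plain list[str] (and, for cloze tasks, masked_positions / target_token_ids as nested int lists), so the data collator must pass these fields through rather than tensorizing them. The package's own collator lives in bead/active_learning/trainers/data_collator.py and is not part of this hunk; what follows is only a minimal sketch of a collator satisfying that contract, and the function name collate_mixed_effects_batch is illustrative, not something shipped in the wheel.

    import torch


    def collate_mixed_effects_batch(
        features: list[dict[str, object]],
    ) -> dict[str, torch.Tensor | list[str] | list[list[int]]]:
        """Stack tensor-like fields; pass list-valued fields through unchanged."""
        passthrough = ("participant_id", "masked_positions", "target_token_ids")
        batch: dict[str, torch.Tensor | list[str] | list[list[int]]] = {}
        for key in features[0]:
            values = [example[key] for example in features]
            if key in passthrough:
                # Kept as plain Python lists so compute_loss sees list[str] for
                # participant_id and list[list[int]] for the cloze fields.
                batch[key] = values
            else:
                # input_ids, attention_mask, labels, ... become stacked tensors.
                batch[key] = torch.stack(
                    [
                        v if isinstance(v, torch.Tensor) else torch.as_tensor(v)
                        for v in values
                    ]
                )
        return batch

Passing a collator of this shape as data_collator= when constructing MixedEffectsTrainer or ClozeMLMTrainer would keep participant IDs as strings end to end; whether the bundled ClozeDataCollator does exactly this cannot be confirmed from this diff alone.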