bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/models/free_text.py
@@ -0,0 +1,773 @@
+"""Free text model for open-ended text generation with GLMM support.
+
+Implements seq2seq generation with participant-level random effects using:
+- Random intercepts: Bias on decoder output logits (token probability shifts)
+- Random slopes: LoRA adapters on decoder attention layers
+
+Architecture: T5-base or BART-base encoder-decoder model
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+from bead.active_learning.config import VarianceComponents
+from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
+from bead.active_learning.models.peft_adapter import create_participant_lora_adapter
+from bead.active_learning.models.random_effects import RandomEffectsManager
+from bead.config.active_learning import FreeTextModelConfig
+from bead.items.item import Item
+from bead.items.item_template import ItemTemplate, TaskType
+
+__all__ = ["FreeTextModel"]
+
+
+class FreeTextModel(ActiveLearningModel):
+    """Model for free_text tasks with participant-level random effects.
+
+    Uses seq2seq architecture (T5 or BART) with three modes:
+    - Fixed effects: Standard encoder-decoder
+    - Random intercepts: Participant-specific bias on output logits
+    - Random slopes: Participant-specific LoRA adapters on decoder
+
+    Parameters
+    ----------
+    config : FreeTextModelConfig
+        Configuration object containing all model parameters.
+
+    Attributes
+    ----------
+    config : FreeTextModelConfig
+        Model configuration.
+    tokenizer : AutoTokenizer
+        Seq2seq tokenizer.
+    model : AutoModelForSeq2SeqLM
+        Base seq2seq model (T5 or BART).
+    encoder : nn.Module
+        Encoder module.
+    base_decoder : nn.Module
+        Base decoder module (shared across participants in fixed/random_intercepts).
+    lm_head : nn.Module
+        Language modeling head (projects decoder output to vocabulary).
+    random_effects : RandomEffectsManager
+        Manager for participant-level random effects.
+    variance_history : list[VarianceComponents]
+        Variance component estimates over training.
+    _is_fitted : bool
+        Whether model has been trained.
+
+    Examples
+    --------
+    >>> from uuid import uuid4
+    >>> from bead.items.item import Item
+    >>> from bead.config.active_learning import FreeTextModelConfig
+    >>> items = [
+    ...     Item(
+    ...         item_template_id=uuid4(),
+    ...         rendered_elements={"prompt": "Summarize: The cat sat."}
+    ...     )
+    ...     for _ in range(10)
+    ... ]
+    >>> labels = ["Cat sits."] * 10
+    >>> config = FreeTextModelConfig(  # doctest: +SKIP
+    ...     num_epochs=1, batch_size=2, device="cpu"
+    ... )
+    >>> model = FreeTextModel(config=config)  # doctest: +SKIP
+    >>> metrics = model.train(items, labels, participant_ids=None)  # doctest: +SKIP
+    """
+
+    def __init__(
+        self,
+        config: FreeTextModelConfig | None = None,
+    ) -> None:
+        """Initialize free text model.
+
+        Parameters
+        ----------
+        config : FreeTextModelConfig | None
+            Configuration object. If None, uses default configuration.
+        """
+        self.config = config or FreeTextModelConfig()
+
+        # Validate mixed_effects configuration
+        super().__init__(self.config)
+
+        # Load tokenizer and model
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.config.model_name)
+
+        # Extract encoder, decoder, and lm_head for fine-grained control
+        self.encoder = self.model.get_encoder()
+        self.base_decoder = self.model.get_decoder()
+        self.lm_head = self.model.lm_head
+
+        self._is_fitted = False
+
+        # Initialize random effects manager
+        self.random_effects: RandomEffectsManager | None = None
+        self.variance_history: list[VarianceComponents] = []
+
+        self.model.to(self.config.device)
+
+    @property
+    def supported_task_types(self) -> list[TaskType]:
+        """Get supported task types.
+
+        Returns
+        -------
+        list[TaskType]
+            List containing "free_text".
+        """
+        return ["free_text"]
+
+    def validate_item_compatibility(
+        self, item: Item, item_template: ItemTemplate
+    ) -> None:
+        """Validate item is compatible with free text model.
+
+        Parameters
+        ----------
+        item : Item
+            Item to validate.
+        item_template : ItemTemplate
+            Template the item was constructed from.
+
+        Raises
+        ------
+        ValueError
+            If task_type is not "free_text".
+        """
+        if item_template.task_type != "free_text":
+            raise ValueError(
+                f"Expected task_type 'free_text', got '{item_template.task_type}'"
+            )
+
+    def _prepare_inputs(self, items: list[Item]) -> list[str]:
+        """Prepare input texts from items.
+
+        For free text tasks, concatenates all rendered elements as the prompt.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Items to encode.
+
+        Returns
+        -------
+        list[str]
+            Input texts.
+        """
+        texts = []
+        for item in items:
+            # Concatenate all rendered elements as input
+            text = " ".join(item.rendered_elements.values())
+            texts.append(text)
+        return texts
+
+    def _prepare_training_data(
+        self,
+        items: list[Item],
+        labels: list[str],
+        participant_ids: list[str],
+        validation_items: list[Item] | None,
+        validation_labels: list[str] | None,
+    ) -> tuple[
+        list[Item],
+        list[str],
+        list[str],
+        list[Item] | None,
+        list[str] | None,
+    ]:
+        """Prepare data for training, including validation.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items.
+        labels : list[str]
+            Training labels (target text strings).
+        participant_ids : list[str]
+            Participant identifiers.
+        validation_items : list[Item] | None
+            Optional validation items.
+        validation_labels : list[str] | None
+            Optional validation labels.
+
+        Returns
+        -------
+        tuple
+            Prepared training data: items, labels, participant_ids,
+            validation_items, validation_labels.
+
+        Raises
+        ------
+        ValueError
+            If labels contain empty strings.
+        """
+        if any(not label for label in labels):
+            raise ValueError(
+                "labels cannot contain empty strings. "
+                "Ensure all labels are non-empty text."
+            )
+
+        val_labels_list: list[str] | None = None
+        if validation_items is not None and validation_labels is not None:
+            if any(not label for label in validation_labels):
+                raise ValueError(
+                    "validation_labels cannot contain empty strings. "
+                    "Ensure all validation labels are non-empty text."
+                )
+            val_labels_list = validation_labels
+
+        return items, labels, participant_ids, validation_items, val_labels_list
+
+    def _do_training(
+        self,
+        items: list[Item],
+        labels_numeric: list[str],
+        participant_ids: list[str],
+        validation_items: list[Item] | None,
+        validation_labels_numeric: list[str] | None,
+    ) -> dict[str, float]:
+        """Perform the actual training logic (custom loop for seq2seq).
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items.
+        labels_numeric : list[str]
+            Training labels (target text strings).
+        participant_ids : list[str]
+            Participant identifiers.
+        validation_items : list[Item] | None
+            Optional validation items.
+        validation_labels_numeric : list[str] | None
+            Optional validation labels.
+
+        Returns
+        -------
+        dict[str, float]
+            Training metrics.
+        """
+        # Prepare inputs
+        input_texts = self._prepare_inputs(items)
+
+        # Get actual vocabulary size from lm_head output dimension
+        vocab_size = self.lm_head.out_features
+
+        # Build optimizer parameters based on mode
+        params_to_optimize = list(self.model.parameters())
+
+        # Add random effects parameters
+        if self.config.mixed_effects.mode == "random_intercepts":
+            for param_dict in self.random_effects.intercepts.values():
+                params_to_optimize.extend(param_dict.values())
+        elif self.config.mixed_effects.mode == "random_slopes":
+            for adapter in self.random_effects.slopes.values():
+                params_to_optimize.extend(adapter.get_lora_parameters())
+
+        optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
+
+        self.model.train()
+
+        for _epoch in range(self.config.num_epochs):
+            n_batches = (
+                len(items) + self.config.batch_size - 1
+            ) // self.config.batch_size
+            epoch_loss = 0.0
+
+            for i in range(n_batches):
+                start_idx = i * self.config.batch_size
+                end_idx = min(start_idx + self.config.batch_size, len(items))
+
+                batch_input_texts = input_texts[start_idx:end_idx]
+                batch_labels = labels_numeric[start_idx:end_idx]
+                batch_participant_ids = participant_ids[start_idx:end_idx]
+
+                # Tokenize inputs and labels
+                inputs = self.tokenizer(
+                    batch_input_texts,
+                    padding=True,
+                    truncation=True,
+                    max_length=self.config.max_input_length,
+                    return_tensors="pt",
+                ).to(self.config.device)
+
+                # Tokenize targets (labels)
+                targets = self.tokenizer(
+                    text_target=batch_labels,
+                    padding=True,
+                    truncation=True,
+                    max_length=self.config.max_output_length,
+                    return_tensors="pt",
+                ).to(self.config.device)
+
+                # Clone before masking so the -100 loss mask is not written
+                # into the tokenizer output in place
+                target_ids = targets["input_ids"].clone()
+                # Replace pad token id with -100 for loss computation
+                target_ids[target_ids == self.tokenizer.pad_token_id] = -100
+
+                # Forward pass depends on mixed effects mode
+                if self.config.mixed_effects.mode == "fixed":
+                    # Standard seq2seq training
+                    outputs = self.model(
+                        **inputs,
+                        labels=target_ids,
+                    )
+                    loss_nll = outputs.loss
+
+                elif self.config.mixed_effects.mode == "random_intercepts":
+                    # Get encoder outputs
+                    encoder_outputs = self.encoder(**inputs)
+
+                    # Right-shift targets into teacher-forcing decoder inputs
+                    # (passing labels= in the fixed branch does this internally)
+                    decoder_input_ids = (
+                        self.model.prepare_decoder_input_ids_from_labels(
+                            labels=target_ids
+                        )
+                    )
+
+                    # Run decoder to get logits
+                    decoder_outputs = self.base_decoder(
+                        input_ids=decoder_input_ids,
+                        encoder_hidden_states=encoder_outputs.last_hidden_state,
+                        encoder_attention_mask=inputs["attention_mask"],
+                    )
+
+                    # Project to vocabulary
+                    logits = self.lm_head(decoder_outputs.last_hidden_state)
+
+                    # Add participant-specific bias to logits
+                    for j, pid in enumerate(batch_participant_ids):
+                        bias = self.random_effects.get_intercepts(
+                            pid,
+                            n_classes=vocab_size,
+                            param_name="mu",
+                            create_if_missing=True,
+                        )
+                        # bias shape: (vocab_size,)
+                        # Add to all positions in sequence
+                        logits[j] = logits[j] + bias
+
+                    # Compute cross-entropy loss
+                    loss_nll = torch.nn.functional.cross_entropy(
+                        logits.view(-1, vocab_size),
+                        target_ids.view(-1),
+                        ignore_index=-100,
+                    )
+
+                elif self.config.mixed_effects.mode == "random_slopes":
+                    # Use participant-specific LoRA adapters
+                    # Need to process each participant separately
+                    decoder_input_ids = (
+                        self.model.prepare_decoder_input_ids_from_labels(
+                            labels=target_ids
+                        )
+                    )
+                    losses = []
+                    for j, pid in enumerate(batch_participant_ids):
+                        # Get participant-specific decoder
+                        participant_decoder = self.random_effects.get_slopes(
+                            pid,
+                            fixed_head=create_participant_lora_adapter(
+                                self.base_decoder,
+                                rank=self.config.lora_rank,
+                                alpha=self.config.lora_alpha,
+                                dropout=self.config.lora_dropout,
+                                target_modules=self.config.lora_target_modules,
+                            ),
+                            create_if_missing=True,
+                        )
+
+                        # Get encoder outputs for this item
+                        item_inputs = {k: v[j : j + 1] for k, v in inputs.items()}
+                        encoder_outputs_j = self.encoder(**item_inputs)
+
+                        # Run participant-specific decoder
+                        decoder_outputs_j = participant_decoder(
+                            input_ids=decoder_input_ids[j : j + 1],
+                            encoder_hidden_states=encoder_outputs_j.last_hidden_state,
+                            encoder_attention_mask=item_inputs["attention_mask"],
+                        )
+
+                        # Project to vocabulary
+                        logits_j = self.lm_head(decoder_outputs_j.last_hidden_state)
+
+                        # Compute loss for this item
+                        loss_j = torch.nn.functional.cross_entropy(
+                            logits_j.view(-1, vocab_size),
+                            target_ids[j : j + 1].view(-1),
+                            ignore_index=-100,
+                        )
+                        losses.append(loss_j)
+
+                    loss_nll = torch.stack(losses).mean()
+
+                # Add prior regularization
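+                # With a Gaussian prior on the participant effects, adding
+                # this penalty to the NLL makes the objective MAP estimation,
+                # which is the GLMM-style shrinkage noted in the module
+                # docstring.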
+                loss_prior = self.random_effects.compute_prior_loss()
+                loss = loss_nll + loss_prior
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                epoch_loss += loss.item()
+
+            epoch_loss = epoch_loss / n_batches
+
+        metrics: dict[str, float] = {
+            "train_loss": epoch_loss,
+        }
+
+        # Estimate variance components
+        if self.config.mixed_effects.estimate_variance_components:
+            var_comps = self.random_effects.estimate_variance_components()
+            if var_comps:
+                var_comp = var_comps.get("mu") or var_comps.get("slopes")
+                if var_comp:
+                    if not hasattr(self, "variance_history"):
+                        self.variance_history = []
+                    self.variance_history.append(var_comp)
+                    metrics["participant_variance"] = var_comp.variance
+                    metrics["n_participants"] = var_comp.n_groups
+
+        # Compute training exact match
+        train_predictions = self._do_predict(items, participant_ids)
+        train_pred_texts = [p.predicted_class for p in train_predictions]
+        metrics["train_exact_match"] = self._compute_exact_match(
+            train_pred_texts, labels_numeric
+        )
+
+        if validation_items is not None and validation_labels_numeric is not None:
+            # Validation
+            if self.config.mixed_effects.mode == "fixed":
+                val_participant_ids = ["_fixed_"] * len(validation_items)
+            else:
+                val_participant_ids = ["_validation_"] * len(validation_items)
+            val_predictions = self._do_predict(validation_items, val_participant_ids)
+
+            val_pred_texts = [p.predicted_class for p in val_predictions]
+            metrics["val_exact_match"] = self._compute_exact_match(
+                val_pred_texts, validation_labels_numeric
+            )
+
+        return metrics
+
+    def _do_predict(
+        self, items: list[Item], participant_ids: list[str]
+    ) -> list[ModelPrediction]:
+        """Generate text for items with participant-specific random effects.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Items to predict.
+        participant_ids : list[str]
+            Participant identifiers.
+
+        Returns
+        -------
+        list[ModelPrediction]
+            Predictions with predicted_class as generated text.
+        """
+        self.model.eval()
+
+        input_texts = self._prepare_inputs(items)
+
+        # Tokenize inputs
+        inputs = self.tokenizer(
+            input_texts,
+            padding=True,
+            truncation=True,
+            max_length=self.config.max_input_length,
+            return_tensors="pt",
+        ).to(self.config.device)
+
+        with torch.no_grad():
+            if self.config.mixed_effects.mode == "fixed":
+                # Standard generation
+                outputs = self.model.generate(
+                    **inputs,
+                    max_length=self.config.max_output_length,
+                    num_beams=self.config.num_beams,
+                    temperature=self.config.temperature,
+                    top_p=self.config.top_p,
+                )
+                generated_texts = self.tokenizer.batch_decode(
+                    outputs, skip_special_tokens=True
+                )
+
+            elif self.config.mixed_effects.mode == "random_intercepts":
+                # Generate with participant-specific bias
+                # For simplicity, use greedy decoding with bias applied at each step
+                # (Full beam search with bias is more complex)
+                generated_texts = []
+                vocab_size = self.lm_head.out_features
+
+                for i, pid in enumerate(participant_ids):
+                    # Get encoder outputs for this item
+                    item_inputs = {k: v[i : i + 1] for k, v in inputs.items()}
+                    encoder_outputs = self.encoder(**item_inputs)
+
+                    # Get participant bias
+                    bias = self.random_effects.get_intercepts(
+                        pid,
+                        n_classes=vocab_size,
+                        param_name="mu",
+                        create_if_missing=False,
+                    )
+
+                    # Greedy decoding with bias
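+                    # The pad token doubles as T5's decoder start token;
+                    # a BART-style checkpoint would need its own
+                    # decoder_start_token_id here.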
+                    decoder_input_ids = torch.tensor(
+                        [[self.tokenizer.pad_token_id]], device=self.config.device
+                    )
+                    generated_ids = []
+
+                    for _ in range(self.config.max_output_length):
+                        decoder_outputs = self.base_decoder(
+                            input_ids=decoder_input_ids,
+                            encoder_hidden_states=encoder_outputs.last_hidden_state,
+                            encoder_attention_mask=item_inputs["attention_mask"],
+                        )
+                        logits = self.lm_head(
+                            decoder_outputs.last_hidden_state[:, -1, :]
+                        )
+
+                        # Add participant bias (bias is 1D, logits is 2D)
+                        logits = logits + bias.unsqueeze(0)
+
+                        # Greedy selection
+                        next_token_id = torch.argmax(logits, dim=-1)
+                        generated_ids.append(next_token_id.item())
+
+                        # Stop if EOS
+                        if next_token_id.item() == self.tokenizer.eos_token_id:
+                            break
+
+                        # Append to decoder input (argmax yields a shape-(1,) tensor)
+                        decoder_input_ids = torch.cat(
+                            [decoder_input_ids, next_token_id.unsqueeze(-1)], dim=1
+                        )
+
+                    # Decode generated text
+                    text = self.tokenizer.decode(
+                        generated_ids, skip_special_tokens=True
+                    )
+                    generated_texts.append(text)
+
+            elif self.config.mixed_effects.mode == "random_slopes":
+                # Generate with participant-specific LoRA decoder
+                generated_texts = []
+
+                for i, pid in enumerate(participant_ids):
+                    # Get participant-specific decoder
+                    participant_decoder = self.random_effects.get_slopes(
+                        pid,
+                        fixed_head=create_participant_lora_adapter(
+                            self.base_decoder,
+                            rank=self.config.lora_rank,
+                            alpha=self.config.lora_alpha,
+                            dropout=self.config.lora_dropout,
+                            target_modules=self.config.lora_target_modules,
+                        ),
+                        create_if_missing=False,
+                    )
+
+                    # Get encoder outputs
+                    item_inputs = {k: v[i : i + 1] for k, v in inputs.items()}
+                    encoder_outputs = self.encoder(**item_inputs)
+
+                    # Greedy decoding with participant decoder
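+                    # (Same T5-style start-token assumption as in the
+                    # random_intercepts branch above.)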
+                    decoder_input_ids = torch.tensor(
+                        [[self.tokenizer.pad_token_id]], device=self.config.device
+                    )
+                    generated_ids = []
+
+                    for _ in range(self.config.max_output_length):
+                        decoder_outputs = participant_decoder(
+                            input_ids=decoder_input_ids,
+                            encoder_hidden_states=encoder_outputs.last_hidden_state,
+                            encoder_attention_mask=item_inputs["attention_mask"],
+                        )
+                        logits = self.lm_head(
+                            decoder_outputs.last_hidden_state[:, -1, :]
+                        )
+
+                        next_token_id = torch.argmax(logits, dim=-1)
+                        generated_ids.append(next_token_id.item())
+
+                        if next_token_id.item() == self.tokenizer.eos_token_id:
+                            break
+
+                        decoder_input_ids = torch.cat(
+                            [decoder_input_ids, next_token_id.unsqueeze(-1)], dim=1
+                        )
+
+                    text = self.tokenizer.decode(
+                        generated_ids, skip_special_tokens=True
+                    )
+                    generated_texts.append(text)
+
+        predictions = []
+        for i, item in enumerate(items):
+            predictions.append(
+                ModelPrediction(
+                    item_id=str(item.id),
+                    probabilities={},  # Not applicable for generation
+                    predicted_class=generated_texts[i],  # Generated text
+                    confidence=1.0,  # Not applicable for generation
+                )
+            )
+
+        return predictions
+
+    def _do_predict_proba(
+        self, items: list[Item], participant_ids: list[str]
+    ) -> np.ndarray:
+        """Predict probabilities (not applicable for free text generation).
+
+        For text generation, returns an empty array.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Items to predict.
+        participant_ids : list[str]
+            Participant identifiers.
+
+        Returns
+        -------
+        np.ndarray
+            Empty array of shape (n_items, 0).
+        """
+        return np.zeros((len(items), 0))
+
+    def _compute_exact_match(self, predictions: list[str], labels: list[str]) -> float:
+        """Compute exact match accuracy.
+
+        Matches are case-insensitive and ignore leading/trailing whitespace.
+
+        Parameters
+        ----------
+        predictions : list[str]
+            Predicted texts.
+        labels : list[str]
+            Ground truth texts.
+
+        Returns
+        -------
+        float
+            Exact match accuracy (fraction of exact matches).
+        """
+        return sum(
+            p.strip().lower() == label.strip().lower()
+            for p, label in zip(predictions, labels, strict=True)
+        ) / len(predictions)
+
+    def _save_model_components(self, save_path: Path) -> None:
+        """Save model-specific components (model, tokenizer).
+
+        Parameters
+        ----------
+        save_path : Path
+            Directory path to save the model.
+        """
+        self.model.save_pretrained(save_path / "model")
+        self.tokenizer.save_pretrained(save_path / "model")
+
+    def _load_model_components(self, load_path: Path) -> None:
+        """Load model-specific components (model, tokenizer).
+
+        Parameters
+        ----------
+        load_path : Path
+            Directory path to load the model from.
+        """
+        # Load config.json to reconstruct config
+        with open(load_path / "config.json") as f:
+            config_dict = json.load(f)
+
+        # Reconstruct MixedEffectsConfig if needed
+        if "mixed_effects" in config_dict and isinstance(
+            config_dict["mixed_effects"], dict
+        ):
+            from bead.active_learning.config import MixedEffectsConfig  # noqa: PLC0415
+
+            config_dict["mixed_effects"] = MixedEffectsConfig(
+                **config_dict["mixed_effects"]
+            )
+
+        from bead.config.active_learning import FreeTextModelConfig  # noqa: PLC0415
+
+        self.config = FreeTextModelConfig(**config_dict)
+
+        # Load model
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(load_path / "model")
+        self.tokenizer = AutoTokenizer.from_pretrained(load_path / "model")
+
+        # Re-extract components
+        self.encoder = self.model.get_encoder()
+        self.base_decoder = self.model.get_decoder()
+        self.lm_head = self.model.lm_head
+
+        self.model.to(self.config.device)
+
+    def _get_save_state(self) -> dict[str, object]:
+        """Get model-specific state to save in config.json.
+
+        Returns
+        -------
+        dict[str, object]
+            Model-specific state dictionary.
+        """
+        return {}
+
+    def _restore_training_state(self, config_dict: dict[str, object]) -> None:
+        """Restore model-specific training state from config_dict.
+
+        Parameters
+        ----------
+        config_dict : dict[str, object]
+            Configuration dictionary.
+        """
+        pass
+
+    def _get_n_classes_for_random_effects(self) -> int:
+        """Get the number of classes for initializing RandomEffectsManager.
+
+        For FreeTextModel, this is the vocabulary size.
+
+        Returns
+        -------
+        int
+            Vocabulary size.
+        """
+        return self.lm_head.out_features
+
+    def _initialize_random_effects(self, n_classes: int, **kwargs: object) -> None:
+        """Initialize the RandomEffectsManager.
+
+        Parameters
+        ----------
+        n_classes : int
+            Vocabulary size (for FreeTextModel).
+        **kwargs : object
+            Additional keyword arguments (not used).
+        """
+        self.random_effects = RandomEffectsManager(
+            self.config.mixed_effects,
+            vocab_size=n_classes,
+        )
+
+    def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
+        """Get the fixed head for random effects.
+
+        For FreeTextModel with random_slopes, returns a template adapter.
+        For other modes, returns None.
+
+        Returns
+        -------
+        torch.nn.Module | None
+            Template adapter for random_slopes, None otherwise.
+        """
+        if self.config.mixed_effects.mode == "random_slopes":
+            # For random_slopes, need to provide a template adapter
+            return create_participant_lora_adapter(
+                self.base_decoder,
+                rank=self.config.lora_rank,
+                alpha=self.config.lora_alpha,
+                dropout=self.config.lora_dropout,
+                target_modules=self.config.lora_target_modules,
+            )
+        return None
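
A minimal usage sketch for the model above. Assumptions: the public train() entry point shown in the class docstring, and FreeTextModelConfig accepting the fields exercised in this file; the participant ids and item texts are illustrative, not from the package.

    from uuid import uuid4

    from bead.active_learning.models.free_text import FreeTextModel
    from bead.config.active_learning import FreeTextModelConfig
    from bead.items.item import Item

    items = [
        Item(
            item_template_id=uuid4(),
            rendered_elements={"prompt": f"Summarize: Example text {i}."},
        )
        for i in range(4)
    ]
    labels = ["Example summary."] * 4
    # Two items per participant so the random-effects modes have groups to fit
    participant_ids = ["p1", "p1", "p2", "p2"]

    config = FreeTextModelConfig(num_epochs=1, batch_size=2, device="cpu")
    model = FreeTextModel(config=config)

    # train() (from the ActiveLearningModel base) runs the _do_training loop
    # shown in the diff; it reports train_loss and train_exact_match
    metrics = model.train(items, labels, participant_ids=participant_ids)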