bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/models/forced_choice.py
@@ -0,0 +1,956 @@
1
+ """Model for forced choice tasks (2AFC, 3AFC, 4AFC, nAFC)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import tempfile
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ from transformers import AutoModel, AutoTokenizer, TrainingArguments
13
+
14
+ from bead.active_learning.config import MixedEffectsConfig, VarianceComponents
15
+ from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
16
+ from bead.active_learning.models.random_effects import RandomEffectsManager
17
+ from bead.active_learning.trainers.data_collator import MixedEffectsDataCollator
18
+ from bead.active_learning.trainers.dataset_utils import items_to_dataset
19
+ from bead.active_learning.trainers.metrics import compute_multiclass_metrics
20
+ from bead.active_learning.trainers.model_wrapper import EncoderClassifierWrapper
21
+ from bead.config.active_learning import ForcedChoiceModelConfig
22
+ from bead.items.item import Item
23
+ from bead.items.item_template import ItemTemplate, TaskType
24
+
25
+ __all__ = ["ForcedChoiceModel"]
26
+
27
+
28
+ class ForcedChoiceModel(ActiveLearningModel):
29
+ """Model for forced_choice tasks with n alternatives.
30
+
31
+ Supports 2AFC, 3AFC, 4AFC, and general nAFC tasks using any
32
+ HuggingFace transformer model. Provides two encoding strategies:
33
+ single encoder (concatenate options) or dual encoder (separate embeddings).
34
+
35
+ Parameters
36
+ ----------
37
+ config : ForcedChoiceModelConfig
38
+ Configuration object containing all model parameters.
39
+
40
+ Attributes
41
+ ----------
42
+ config : ForcedChoiceModelConfig
43
+ Model configuration.
44
+ tokenizer : AutoTokenizer
45
+ Transformer tokenizer.
46
+ encoder : AutoModel
47
+ Transformer encoder model.
48
+ classifier_head : nn.Sequential
49
+ Classification head (fixed effects head).
50
+ num_classes : int | None
51
+ Number of classes (inferred from training data).
52
+ option_names : list[str] | None
53
+ Option names (e.g., ["option_a", "option_b"]).
54
+ random_effects : RandomEffectsManager
55
+ Manager for participant-level random effects.
56
+ variance_history : list[VarianceComponents]
57
+ Variance component estimates over training (for diagnostics).
58
+ _is_fitted : bool
59
+ Whether the model has been trained.
60
+
61
+ Examples
62
+ --------
63
+ >>> from uuid import uuid4
64
+ >>> from bead.items.item import Item
65
+ >>> from bead.config.active_learning import ForcedChoiceModelConfig
66
+ >>> items = [
67
+ ... Item(
68
+ ... item_template_id=uuid4(),
69
+ ... rendered_elements={"option_a": "sentence A", "option_b": "sentence B"}
70
+ ... )
71
+ ... for _ in range(10)
72
+ ... ]
73
+ >>> labels = ["option_a"] * 5 + ["option_b"] * 5
74
+ >>> config = ForcedChoiceModelConfig( # doctest: +SKIP
75
+ ... num_epochs=1, batch_size=2, device="cpu"
76
+ ... )
77
+ >>> model = ForcedChoiceModel(config=config) # doctest: +SKIP
78
+ >>> metrics = model.train(items, labels) # doctest: +SKIP
79
+ >>> predictions = model.predict(items[:3]) # doctest: +SKIP
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ config: ForcedChoiceModelConfig | None = None,
85
+ ) -> None:
86
+ """Initialize forced choice model.
87
+
88
+ Parameters
89
+ ----------
90
+ config : ForcedChoiceModelConfig | None
91
+ Configuration object. If None, uses default configuration.
92
+ """
93
+ self.config = config or ForcedChoiceModelConfig()
94
+
95
+ # Validate mixed_effects configuration
96
+ super().__init__(self.config)
97
+
98
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
99
+ self.encoder = AutoModel.from_pretrained(self.config.model_name)
100
+
101
+ self.num_classes: int | None = None
102
+ self.option_names: list[str] | None = None
103
+ self.classifier_head: nn.Sequential | None = None
104
+ self._is_fitted = False
105
+
106
+ # Initialize random effects manager
107
+ self.random_effects: RandomEffectsManager | None = None
108
+ self.variance_history: list[VarianceComponents] = []
109
+
110
+ self.encoder.to(self.config.device)
111
+
112
+ @property
113
+ def supported_task_types(self) -> list[TaskType]:
114
+ """Get supported task types.
115
+
116
+ Returns
117
+ -------
118
+ list[TaskType]
119
+ List containing "forced_choice".
120
+ """
121
+ return ["forced_choice"]
122
+
123
+ def validate_item_compatibility(
124
+ self, item: Item, item_template: ItemTemplate
125
+ ) -> None:
126
+ """Validate item is compatible with forced choice model.
127
+
128
+ Parameters
129
+ ----------
130
+ item : Item
131
+ Item to validate.
132
+ item_template : ItemTemplate
133
+ Template the item was constructed from.
134
+
135
+ Raises
136
+ ------
137
+ ValueError
138
+ If task_type is not "forced_choice".
139
+ ValueError
140
+ If task_spec.options is not defined.
141
+ ValueError
142
+ If item is missing required rendered_elements.
143
+ """
144
+ if item_template.task_type != "forced_choice":
145
+ raise ValueError(
146
+ f"Expected task_type 'forced_choice', got '{item_template.task_type}'"
147
+ )
148
+
149
+ if item_template.task_spec.options is None:
150
+ raise ValueError(
151
+ "task_spec.options must be defined for forced_choice tasks"
152
+ )
153
+
154
+ for option_name in item_template.task_spec.options:
155
+ if option_name not in item.rendered_elements:
156
+ raise ValueError(
157
+ f"Item missing required element '{option_name}' "
158
+ f"from rendered_elements"
159
+ )
160
+
161
+ def _initialize_classifier(self, num_classes: int) -> None:
162
+ """Initialize classification head for given number of classes.
163
+
164
+ Parameters
165
+ ----------
166
+ num_classes : int
167
+ Number of output classes.
168
+ """
169
+ hidden_size = self.encoder.config.hidden_size
170
+
171
+ if self.config.encoder_mode == "dual_encoder":
172
+ input_size = hidden_size * num_classes
173
+ else:
174
+ input_size = hidden_size
175
+
176
+ self.classifier_head = nn.Sequential(
177
+ nn.Linear(input_size, 256),
178
+ nn.ReLU(),
179
+ nn.Dropout(0.1),
180
+ nn.Linear(256, num_classes),
181
+ )
182
+ self.classifier_head.to(self.config.device)
183
+
184
+ def _encode_single(self, texts: list[str]) -> torch.Tensor:
185
+ """Encode texts using single encoder strategy.
186
+
187
+ Concatenates all option texts with [SEP] tokens and encodes once.
188
+
189
+ Parameters
190
+ ----------
191
+ texts : list[str]
192
+ List of concatenated option texts for each item.
193
+
194
+ Returns
195
+ -------
196
+ torch.Tensor
197
+ Encoded representations of shape (batch_size, hidden_size).
198
+ """
199
+ encodings = self.tokenizer(
200
+ texts,
201
+ padding=True,
202
+ truncation=True,
203
+ max_length=self.config.max_length,
204
+ return_tensors="pt",
205
+ )
206
+ encodings = {k: v.to(self.config.device) for k, v in encodings.items()}
207
+
208
+ outputs = self.encoder(**encodings)
209
+ return outputs.last_hidden_state[:, 0, :]
210
+
211
+ def _encode_dual(self, options_per_item: list[list[str]]) -> torch.Tensor:
212
+ """Encode texts using dual encoder strategy.
213
+
214
+ Encodes each option separately and concatenates embeddings.
215
+
216
+ Parameters
217
+ ----------
218
+ options_per_item : list[list[str]]
219
+ List of option lists. Each inner list contains option texts for one item.
220
+
221
+ Returns
222
+ -------
223
+ torch.Tensor
224
+ Concatenated encodings of shape (batch_size, hidden_size * num_options).
225
+ """
226
+ all_embeddings = []
227
+
228
+ for options in options_per_item:
229
+ option_embeddings = []
230
+ for option_text in options:
231
+ encodings = self.tokenizer(
232
+ [option_text],
233
+ padding=True,
234
+ truncation=True,
235
+ max_length=self.config.max_length,
236
+ return_tensors="pt",
237
+ )
238
+ encodings = {k: v.to(self.config.device) for k, v in encodings.items()}
239
+
240
+ outputs = self.encoder(**encodings)
241
+ cls_embedding = outputs.last_hidden_state[0, 0, :]
242
+ option_embeddings.append(cls_embedding)
243
+
244
+ concatenated = torch.cat(option_embeddings, dim=0)
245
+ all_embeddings.append(concatenated)
246
+
247
+ return torch.stack(all_embeddings)
248
+
249
+ def _prepare_inputs(self, items: list[Item]) -> torch.Tensor:
250
+ """Prepare inputs for encoding based on encoder mode.
251
+
252
+ Parameters
253
+ ----------
254
+ items : list[Item]
255
+ Items to encode.
256
+
257
+ Returns
258
+ -------
259
+ torch.Tensor
260
+ Encoded representations.
261
+ """
262
+ if self.option_names is None:
263
+ raise ValueError("Model not initialized. Call train() first.")
264
+
265
+ if self.config.encoder_mode == "single_encoder":
266
+ texts = []
267
+ for item in items:
268
+ option_texts = [
269
+ item.rendered_elements.get(opt, "") for opt in self.option_names
270
+ ]
271
+ concatenated = " [SEP] ".join(option_texts)
272
+ texts.append(concatenated)
273
+ return self._encode_single(texts)
274
+ else:
275
+ options_per_item = []
276
+ for item in items:
277
+ option_texts = [
278
+ item.rendered_elements.get(opt, "") for opt in self.option_names
279
+ ]
280
+ options_per_item.append(option_texts)
281
+ return self._encode_dual(options_per_item)
282
+
283
+ def _validate_labels(self, labels: list[str]) -> None:
284
+ """Validate that all labels are valid option names.
285
+
286
+ Parameters
287
+ ----------
288
+ labels : list[str]
289
+ Labels to validate.
290
+
291
+ Raises
292
+ ------
293
+ ValueError
294
+ If any label is not in option_names.
295
+ """
296
+ if self.option_names is None:
297
+ raise ValueError("option_names not initialized")
298
+
299
+ valid_labels = set(self.option_names)
300
+ invalid = [label for label in labels if label not in valid_labels]
301
+ if invalid:
302
+ raise ValueError(
303
+ f"Invalid labels found: {set(invalid)}. "
304
+ f"Labels must be one of {valid_labels}."
305
+ )
306
+
307
+ def _prepare_training_data(
308
+ self,
309
+ items: list[Item],
310
+ labels: list[str],
311
+ participant_ids: list[str],
312
+ validation_items: list[Item] | None,
313
+ validation_labels: list[str] | None,
314
+ ) -> tuple[list[Item], list[int], list[str], list[Item] | None, list[int] | None]:
315
+ """Prepare training data for forced choice model.
316
+
317
+ Parameters
318
+ ----------
319
+ items : list[Item]
320
+ Training items.
321
+ labels : list[str]
322
+ Training labels (option names).
323
+ participant_ids : list[str]
324
+ Normalized participant IDs.
325
+ validation_items : list[Item] | None
326
+ Validation items.
327
+ validation_labels : list[str] | None
328
+ Validation labels.
329
+
330
+ Returns
331
+ -------
332
+ tuple[list[Item], list[int], list[str], list[Item] | None, list[int] | None]
333
+ Prepared items, numeric labels, participant_ids, validation_items,
334
+ numeric validation_labels.
335
+ """
336
+ unique_labels = sorted(set(labels))
337
+ self.num_classes = len(unique_labels)
338
+ self.option_names = unique_labels
339
+
340
+ self._validate_labels(labels)
341
+ self._initialize_classifier(self.num_classes)
342
+
343
+ label_to_idx = {label: idx for idx, label in enumerate(self.option_names)}
344
+ y_numeric = [label_to_idx[label] for label in labels]
345
+
346
+ # Convert validation labels if provided
347
+ val_y_numeric = None
348
+ if validation_items is not None and validation_labels is not None:
349
+ self._validate_labels(validation_labels)
350
+ if len(validation_items) != len(validation_labels):
351
+ raise ValueError(
352
+ f"Number of validation items ({len(validation_items)}) "
353
+ f"must match number of validation labels ({len(validation_labels)})"
354
+ )
355
+ val_y_numeric = [label_to_idx[label] for label in validation_labels]
356
+
357
+ return items, y_numeric, participant_ids, validation_items, val_y_numeric
358
+
359
+ def _initialize_random_effects(self, n_classes: int) -> None:
360
+ """Initialize random effects manager.
361
+
362
+ Parameters
363
+ ----------
364
+ n_classes : int
365
+ Number of classes.
366
+ """
367
+ self.random_effects = RandomEffectsManager(
368
+ self.config.mixed_effects, n_classes=n_classes
369
+ )
370
+
371
+ def _do_training(
372
+ self,
373
+ items: list[Item],
374
+ labels_numeric: list[int],
375
+ participant_ids: list[str],
376
+ validation_items: list[Item] | None,
377
+ validation_labels_numeric: list[int] | None,
378
+ ) -> dict[str, float]:
379
+ """Perform forced choice model training.
380
+
381
+ Parameters
382
+ ----------
383
+ items : list[Item]
384
+ Training items.
385
+ labels_numeric : list[int]
386
+ Numeric labels (class indices).
387
+ participant_ids : list[str]
388
+ Participant IDs.
389
+ validation_items : list[Item] | None
390
+ Validation items.
391
+ validation_labels_numeric : list[int] | None
392
+ Numeric validation labels.
393
+
394
+ Returns
395
+ -------
396
+ dict[str, float]
397
+ Training metrics.
398
+ """
399
+ # Convert validation_labels_numeric back to string labels for validation metrics
400
+ validation_labels = None
401
+ if validation_items is not None and validation_labels_numeric is not None:
402
+ validation_labels = [
403
+ self.option_names[label_idx] for label_idx in validation_labels_numeric
404
+ ]
405
+
406
+ # Use HuggingFace Trainer for fixed and random_intercepts modes
407
+ if self.config.mixed_effects.mode in ("fixed", "random_intercepts"):
408
+ metrics = self._train_with_huggingface_trainer(
409
+ items=items,
410
+ y_numeric=labels_numeric,
411
+ participant_ids=participant_ids,
412
+ validation_items=validation_items,
413
+ validation_labels=validation_labels,
414
+ )
415
+ else:
416
+ # Use custom loop for random_slopes mode
417
+ metrics = self._train_with_custom_loop(
418
+ items=items,
419
+ y_numeric=labels_numeric,
420
+ participant_ids=participant_ids,
421
+ validation_items=validation_items,
422
+ validation_labels=validation_labels,
423
+ )
424
+
425
+ # Add validation accuracy if validation data provided and not already computed
426
+ if (
427
+ validation_items is not None
428
+ and validation_labels is not None
429
+ and "val_accuracy" not in metrics
430
+ ):
431
+ # Validation with placeholder participant_ids for mixed effects
432
+ if self.config.mixed_effects.mode == "fixed":
433
+ val_participant_ids = ["_fixed_"] * len(validation_items)
434
+ else:
435
+ val_participant_ids = ["_validation_"] * len(validation_items)
436
+ val_predictions = self._do_predict(validation_items, val_participant_ids)
437
+ val_pred_labels = [p.predicted_class for p in val_predictions]
438
+ val_acc = sum(
439
+ pred == true
440
+ for pred, true in zip(val_pred_labels, validation_labels, strict=True)
441
+ ) / len(validation_labels)
442
+ metrics["val_accuracy"] = val_acc
443
+
444
+ return metrics
445
+
446
+ def _train_with_huggingface_trainer(
447
+ self,
448
+ items: list[Item],
449
+ y_numeric: list[int],
450
+ participant_ids: list[str],
451
+ validation_items: list[Item] | None,
452
+ validation_labels: list[str] | None,
453
+ ) -> dict[str, float]:
454
+ """Train using HuggingFace Trainer with mixed effects support.
455
+
456
+ Parameters
457
+ ----------
458
+ items : list[Item]
459
+ Training items.
460
+ y_numeric : list[int]
461
+ Numeric labels (class indices).
462
+ participant_ids : list[str]
463
+ Participant IDs.
464
+ validation_items : list[Item] | None
465
+ Validation items.
466
+ validation_labels : list[str] | None
467
+ Validation labels.
468
+
469
+ Returns
470
+ -------
471
+ dict[str, float]
472
+ Training metrics.
473
+ """
474
+ # Convert items to HuggingFace Dataset
475
+ train_dataset = items_to_dataset(
476
+ items=items,
477
+ labels=y_numeric,
478
+ participant_ids=participant_ids,
479
+ tokenizer=self.tokenizer,
480
+ max_length=self.config.max_length,
481
+ )
482
+
483
+ # Create validation dataset if provided
484
+ eval_dataset = None
485
+ if validation_items is not None and validation_labels is not None:
486
+ label_to_idx = {label: idx for idx, label in enumerate(self.option_names)}
487
+ val_y_numeric = [label_to_idx[label] for label in validation_labels]
488
+ val_participant_ids = (
489
+ ["_validation_"] * len(validation_items)
490
+ if self.config.mixed_effects.mode != "fixed"
491
+ else ["_fixed_"] * len(validation_items)
492
+ )
493
+ eval_dataset = items_to_dataset(
494
+ items=validation_items,
495
+ labels=val_y_numeric,
496
+ participant_ids=val_participant_ids,
497
+ tokenizer=self.tokenizer,
498
+ max_length=self.config.max_length,
499
+ )
500
+
501
+ # Create wrapper model for Trainer
502
+ wrapped_model = EncoderClassifierWrapper(
503
+ encoder=self.encoder, classifier_head=self.classifier_head
504
+ )
505
+
506
+ # Create data collator
507
+ data_collator = MixedEffectsDataCollator(tokenizer=self.tokenizer)
508
+
509
+ # Create metrics computation function
510
+ def compute_metrics_fn(eval_pred: object) -> dict[str, float]:
511
+ return compute_multiclass_metrics(eval_pred, num_labels=self.num_classes)
512
+
513
+ # Create training arguments with checkpointing
514
+ with tempfile.TemporaryDirectory() as tmpdir:
515
+ checkpoint_dir = Path(tmpdir) / "checkpoints"
516
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
517
+
518
+ training_args = TrainingArguments(
519
+ output_dir=str(checkpoint_dir),
520
+ num_train_epochs=self.config.num_epochs,
521
+ per_device_train_batch_size=self.config.batch_size,
522
+ per_device_eval_batch_size=self.config.batch_size,
523
+ learning_rate=self.config.learning_rate,
524
+ logging_steps=10,
525
+ eval_strategy="epoch" if eval_dataset is not None else "no",
526
+ save_strategy="epoch",
527
+ save_total_limit=1,
528
+ load_best_model_at_end=False,
529
+ report_to="none",
530
+ remove_unused_columns=False,
531
+ use_cpu=self.config.device == "cpu",
532
+ )
533
+
534
+ # Import here to avoid circular import
535
+ from bead.active_learning.trainers.mixed_effects import ( # noqa: PLC0415
536
+ MixedEffectsTrainer,
537
+ )
538
+
539
+ # Create trainer
540
+ trainer = MixedEffectsTrainer(
541
+ model=wrapped_model,
542
+ args=training_args,
543
+ train_dataset=train_dataset,
544
+ eval_dataset=eval_dataset,
545
+ data_collator=data_collator,
546
+ tokenizer=self.tokenizer,
547
+ random_effects_manager=self.random_effects,
548
+ compute_metrics=compute_metrics_fn,
549
+ )
550
+
551
+ # Train
552
+ train_result = trainer.train()
553
+
554
+ # Get training metrics
555
+ train_metrics = trainer.evaluate(eval_dataset=train_dataset)
556
+ metrics: dict[str, float] = {
557
+ "train_loss": float(train_result.training_loss),
558
+ "train_accuracy": train_metrics.get("eval_accuracy", 0.0),
559
+ "train_precision": train_metrics.get("eval_precision", 0.0),
560
+ "train_recall": train_metrics.get("eval_recall", 0.0),
561
+ "train_f1": train_metrics.get("eval_f1", 0.0),
562
+ }
563
+
564
+ # Get validation metrics if eval_dataset was provided
565
+ if eval_dataset is not None:
566
+ val_metrics = trainer.evaluate(eval_dataset=eval_dataset)
567
+ metrics.update(
568
+ {
569
+ "val_accuracy": val_metrics.get("eval_accuracy", 0.0),
570
+ "val_precision": val_metrics.get("eval_precision", 0.0),
571
+ "val_recall": val_metrics.get("eval_recall", 0.0),
572
+ "val_f1": val_metrics.get("eval_f1", 0.0),
573
+ }
574
+ )
575
+
576
+ # Estimate variance components
577
+ if self.config.mixed_effects.estimate_variance_components:
578
+ var_comps = self.random_effects.estimate_variance_components()
579
+ if var_comps:
580
+ var_comp = var_comps.get("mu") or var_comps.get("slopes")
581
+ if var_comp:
582
+ self.variance_history.append(var_comp)
583
+ metrics["participant_variance"] = var_comp.variance
584
+ metrics["n_participants"] = var_comp.n_groups
585
+
586
+ return metrics
587
+
588
+ def _train_with_custom_loop(
589
+ self,
590
+ items: list[Item],
591
+ y_numeric: list[int],
592
+ participant_ids: list[str],
593
+ validation_items: list[Item] | None,
594
+ validation_labels: list[str] | None,
595
+ ) -> dict[str, float]:
596
+ """Train using custom training loop (for random_slopes mode).
597
+
598
+ Parameters
599
+ ----------
600
+ items : list[Item]
601
+ Training items.
602
+ y_numeric : list[int]
603
+ Numeric labels (class indices).
604
+ participant_ids : list[str]
605
+ Participant IDs.
606
+ validation_items : list[Item] | None
607
+ Validation items.
608
+ validation_labels : list[str] | None
609
+ Validation labels.
610
+
611
+ Returns
612
+ -------
613
+ dict[str, float]
614
+ Training metrics.
615
+ """
616
+ # Convert to tensor
617
+ y = torch.tensor(y_numeric, dtype=torch.long, device=self.config.device)
618
+
619
+ # Build optimizer parameters
620
+ params_to_optimize = list(self.encoder.parameters()) + list(
621
+ self.classifier_head.parameters()
622
+ )
623
+
624
+ # Add random effects parameters (for random_slopes)
625
+ if self.config.mixed_effects.mode == "random_slopes":
626
+ for head in self.random_effects.slopes.values():
627
+ params_to_optimize.extend(head.parameters())
628
+
629
+ optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
630
+ criterion = nn.CrossEntropyLoss()
631
+
632
+ self.encoder.train()
633
+ self.classifier_head.train()
634
+
635
+ for _epoch in range(self.config.num_epochs):
636
+ n_batches = (
637
+ len(items) + self.config.batch_size - 1
638
+ ) // self.config.batch_size
639
+ epoch_loss = 0.0
640
+ epoch_correct = 0
641
+
642
+ for i in range(n_batches):
643
+ start_idx = i * self.config.batch_size
644
+ end_idx = min(start_idx + self.config.batch_size, len(items))
645
+
646
+ batch_items = items[start_idx:end_idx]
647
+ batch_labels = y[start_idx:end_idx]
648
+ batch_participant_ids = participant_ids[start_idx:end_idx]
649
+
650
+ embeddings = self._prepare_inputs(batch_items)
651
+
652
+ # Forward pass depends on mixed effects mode
653
+ if self.config.mixed_effects.mode == "fixed":
654
+ # Standard forward pass
655
+ logits = self.classifier_head(embeddings)
656
+
657
+ elif self.config.mixed_effects.mode == "random_intercepts":
658
+ # Fixed head + per-participant bias
659
+ logits = self.classifier_head(embeddings)
660
+ for j, pid in enumerate(batch_participant_ids):
661
+ bias = self.random_effects.get_intercepts(
662
+ pid,
663
+ n_classes=self.num_classes,
664
+ param_name="mu",
665
+ create_if_missing=True,
666
+ )
667
+ logits[j] = logits[j] + bias
668
+
669
+ elif self.config.mixed_effects.mode == "random_slopes":
670
+ # Per-participant head
671
+ logits_list = []
672
+ for j, pid in enumerate(batch_participant_ids):
673
+ participant_head = self.random_effects.get_slopes(
674
+ pid,
675
+ fixed_head=self.classifier_head,
676
+ create_if_missing=True,
677
+ )
678
+ logits_j = participant_head(embeddings[j : j + 1])
679
+ logits_list.append(logits_j)
680
+ logits = torch.cat(logits_list, dim=0)
681
+
682
+ # Data loss + prior regularization
683
+ loss_ce = criterion(logits, batch_labels)
684
+ loss_prior = self.random_effects.compute_prior_loss()
685
+ loss = loss_ce + loss_prior
686
+
687
+ optimizer.zero_grad()
688
+ loss.backward()
689
+ optimizer.step()
690
+
691
+ epoch_loss += loss.item()
692
+ predictions = torch.argmax(logits, dim=1)
693
+ epoch_correct += (predictions == batch_labels).sum().item()
694
+
695
+ epoch_acc = epoch_correct / len(items)
696
+ epoch_loss = epoch_loss / n_batches
697
+
698
+ self._is_fitted = True
699
+
700
+ metrics: dict[str, float] = {
701
+ "train_accuracy": epoch_acc,
702
+ "train_loss": epoch_loss,
703
+ }
704
+
705
+ # Estimate variance components
706
+ if self.config.mixed_effects.estimate_variance_components:
707
+ var_comps = self.random_effects.estimate_variance_components()
708
+ if var_comps:
709
+ var_comp = var_comps.get("mu") or var_comps.get("slopes")
710
+ if var_comp:
711
+ self.variance_history.append(var_comp)
712
+ metrics["participant_variance"] = var_comp.variance
713
+ metrics["n_participants"] = var_comp.n_groups
714
+
715
+ if validation_items is not None and validation_labels is not None:
716
+ self._validate_labels(validation_labels)
717
+
718
+ if len(validation_items) != len(validation_labels):
719
+ raise ValueError(
720
+ f"Number of validation items ({len(validation_items)}) "
721
+ f"must match number of validation labels ({len(validation_labels)})"
722
+ )
723
+
724
+ # Validation with placeholder participant_ids for mixed effects
725
+ if self.config.mixed_effects.mode == "fixed":
726
+ val_predictions = self.predict(validation_items, participant_ids=None)
727
+ else:
728
+ val_participant_ids = ["_validation_"] * len(validation_items)
729
+ val_predictions = self.predict(
730
+ validation_items, participant_ids=val_participant_ids
731
+ )
732
+ val_pred_labels = [p.predicted_class for p in val_predictions]
733
+ val_acc = sum(
734
+ pred == true
735
+ for pred, true in zip(val_pred_labels, validation_labels, strict=True)
736
+ ) / len(validation_labels)
737
+ metrics["val_accuracy"] = val_acc
738
+
739
+ return metrics
740
+
741
+ def _do_predict(
742
+ self, items: list[Item], participant_ids: list[str]
743
+ ) -> list[ModelPrediction]:
744
+ """Perform forced choice model prediction.
745
+
746
+ Parameters
747
+ ----------
748
+ items : list[Item]
749
+ Items to predict.
750
+ participant_ids : list[str]
751
+ Normalized participant IDs.
752
+
753
+ Returns
754
+ -------
755
+ list[ModelPrediction]
756
+ Predictions.
757
+ """
758
+ self.encoder.eval()
759
+ self.classifier_head.eval()
760
+
761
+ with torch.no_grad():
762
+ embeddings = self._prepare_inputs(items)
763
+
764
+ # Forward pass depends on mixed effects mode
765
+ if self.config.mixed_effects.mode == "fixed":
766
+ logits = self.classifier_head(embeddings)
767
+
768
+ elif self.config.mixed_effects.mode == "random_intercepts":
769
+ logits = self.classifier_head(embeddings)
770
+ for i, pid in enumerate(participant_ids):
771
+ # Unknown participants: use prior mean (zero bias)
772
+ bias = self.random_effects.get_intercepts(
773
+ pid,
774
+ n_classes=self.num_classes,
775
+ param_name="mu",
776
+ create_if_missing=False,
777
+ )
778
+ logits[i] = logits[i] + bias
779
+
780
+ elif self.config.mixed_effects.mode == "random_slopes":
781
+ logits_list = []
782
+ for i, pid in enumerate(participant_ids):
783
+ # Unknown participants: use fixed head
784
+ participant_head = self.random_effects.get_slopes(
785
+ pid, fixed_head=self.classifier_head, create_if_missing=False
786
+ )
787
+ logits_i = participant_head(embeddings[i : i + 1])
788
+ logits_list.append(logits_i)
789
+ logits = torch.cat(logits_list, dim=0)
790
+
791
+ proba = torch.softmax(logits, dim=1).cpu().numpy()
792
+ pred_classes = torch.argmax(logits, dim=1).cpu().numpy()
793
+
794
+ predictions = []
795
+ for i, item in enumerate(items):
796
+ pred_label = self.option_names[pred_classes[i]]
797
+ prob_dict = {
798
+ opt: float(proba[i, idx]) for idx, opt in enumerate(self.option_names)
799
+ }
800
+ predictions.append(
801
+ ModelPrediction(
802
+ item_id=str(item.id),
803
+ probabilities=prob_dict,
804
+ predicted_class=pred_label,
805
+ confidence=float(proba[i, pred_classes[i]]),
806
+ )
807
+ )
808
+
809
+ return predictions
810
+
811
+ def _do_predict_proba(
812
+ self, items: list[Item], participant_ids: list[str]
813
+ ) -> np.ndarray:
814
+ """Perform forced choice model probability prediction.
815
+
816
+ Parameters
817
+ ----------
818
+ items : list[Item]
819
+ Items to predict.
820
+ participant_ids : list[str]
821
+ Normalized participant IDs.
822
+
823
+ Returns
824
+ -------
825
+ np.ndarray
826
+ Probability array of shape (n_items, n_classes).
827
+ """
828
+ self.encoder.eval()
829
+ self.classifier_head.eval()
830
+
831
+ with torch.no_grad():
832
+ embeddings = self._prepare_inputs(items)
833
+
834
+ # Forward pass depends on mixed effects mode
835
+ if self.config.mixed_effects.mode == "fixed":
836
+ logits = self.classifier_head(embeddings)
837
+
838
+ elif self.config.mixed_effects.mode == "random_intercepts":
839
+ logits = self.classifier_head(embeddings)
840
+ for i, pid in enumerate(participant_ids):
841
+ bias = self.random_effects.get_intercepts(
842
+ pid,
843
+ n_classes=self.num_classes,
844
+ param_name="mu",
845
+ create_if_missing=False,
846
+ )
847
+ logits[i] = logits[i] + bias
848
+
849
+ elif self.config.mixed_effects.mode == "random_slopes":
850
+ logits_list = []
851
+ for i, pid in enumerate(participant_ids):
852
+ participant_head = self.random_effects.get_slopes(
853
+ pid, fixed_head=self.classifier_head, create_if_missing=False
854
+ )
855
+ logits_i = participant_head(embeddings[i : i + 1])
856
+ logits_list.append(logits_i)
857
+ logits = torch.cat(logits_list, dim=0)
858
+
859
+ proba = torch.softmax(logits, dim=1).cpu().numpy()
860
+
861
+ return proba
862
+
863
+ def _save_model_components(self, save_path: Path) -> None:
864
+ """Save model-specific components.
865
+
866
+ Parameters
867
+ ----------
868
+ save_path : Path
869
+ Directory to save to.
870
+ """
871
+ self.encoder.save_pretrained(save_path / "encoder")
872
+ self.tokenizer.save_pretrained(save_path / "encoder")
873
+
874
+ torch.save(
875
+ self.classifier_head.state_dict(),
876
+ save_path / "classifier_head.pt",
877
+ )
878
+
879
+ def _get_save_state(self) -> dict[str, object]:
880
+ """Get model-specific state to save.
881
+
882
+ Returns
883
+ -------
884
+ dict[str, object]
885
+ State dictionary.
886
+ """
887
+ return {
888
+ "num_classes": self.num_classes,
889
+ "option_names": self.option_names,
890
+ }
891
+
892
+ def _restore_training_state(self, config_dict: dict[str, object]) -> None:
893
+ """Restore model-specific training state.
894
+
895
+ Parameters
896
+ ----------
897
+ config_dict : dict[str, object]
898
+ Configuration dictionary with training state.
899
+ """
900
+ self.num_classes = config_dict.pop("num_classes")
901
+ self.option_names = config_dict.pop("option_names")
902
+
903
+ def _load_model_components(self, load_path: Path) -> None:
904
+ """Load model-specific components.
905
+
906
+ Parameters
907
+ ----------
908
+ load_path : Path
909
+ Directory to load from.
910
+ """
911
+ # Load config.json to reconstruct config
912
+ with open(load_path / "config.json") as f:
913
+ config_dict = json.load(f)
914
+
915
+ # Reconstruct MixedEffectsConfig if needed
916
+ if "mixed_effects" in config_dict and isinstance(
917
+ config_dict["mixed_effects"], dict
918
+ ):
919
+ config_dict["mixed_effects"] = MixedEffectsConfig(
920
+ **config_dict["mixed_effects"]
921
+ )
922
+
923
+ self.config = ForcedChoiceModelConfig(**config_dict)
924
+
925
+ self.encoder = AutoModel.from_pretrained(load_path / "encoder")
926
+ self.tokenizer = AutoTokenizer.from_pretrained(load_path / "encoder")
927
+
928
+ self._initialize_classifier(self.num_classes)
929
+ self.classifier_head.load_state_dict(
930
+ torch.load(
931
+ load_path / "classifier_head.pt", map_location=self.config.device
932
+ )
933
+ )
934
+
935
+ self.encoder.to(self.config.device)
936
+ self.classifier_head.to(self.config.device)
937
+
938
+ def _get_n_classes_for_random_effects(self) -> int:
939
+ """Get the number of classes for initializing RandomEffectsManager.
940
+
941
+ Returns
942
+ -------
943
+ int
944
+ Number of classes.
945
+ """
946
+ return self.num_classes
947
+
948
+ def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
949
+ """Get the fixed head for random effects.
950
+
951
+ Returns
952
+ -------
953
+ torch.nn.Module | None
954
+ The classifier head, or None if not applicable.
955
+ """
956
+ return self.classifier_head
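
For orientation, below is a minimal usage sketch assembled from the ForcedChoiceModel docstring example above. It assumes the bead package is installed; the import path bead.active_learning.models.forced_choice follows the file listing, and the encoder_mode values ("single_encoder" vs. "dual_encoder") are those checked in _initialize_classifier and _prepare_inputs. It is an illustrative sketch, not part of the packaged code.

from uuid import uuid4

from bead.active_learning.models.forced_choice import ForcedChoiceModel
from bead.config.active_learning import ForcedChoiceModelConfig
from bead.items.item import Item

# Ten toy two-alternative items; each label names the preferred option.
items = [
    Item(
        item_template_id=uuid4(),
        rendered_elements={"option_a": "sentence A", "option_b": "sentence B"},
    )
    for _ in range(10)
]
labels = ["option_a"] * 5 + ["option_b"] * 5

# "single_encoder" concatenates the options with [SEP] before encoding;
# "dual_encoder" embeds each option separately and concatenates the embeddings.
config = ForcedChoiceModelConfig(num_epochs=1, batch_size=2, device="cpu")
model = ForcedChoiceModel(config=config)

metrics = model.train(items, labels)    # e.g. train_loss, train_accuracy
predictions = model.predict(items[:3])  # list[ModelPrediction] with per-option probabilities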