bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,910 @@
1
+ """Binary model for yes/no or true/false judgments.
2
+
3
+ Expected architecture: Binary classification with 2-class output.
4
+ Different from 2AFC in semantics - represents absolute judgment rather than choice.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import tempfile
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from transformers import AutoModel, AutoTokenizer, TrainingArguments
16
+
17
+ from bead.active_learning.config import VarianceComponents
18
+ from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
19
+ from bead.active_learning.models.random_effects import RandomEffectsManager
20
+ from bead.active_learning.trainers.data_collator import MixedEffectsDataCollator
21
+ from bead.active_learning.trainers.dataset_utils import items_to_dataset
22
+ from bead.active_learning.trainers.metrics import compute_binary_metrics
23
+ from bead.active_learning.trainers.model_wrapper import EncoderClassifierWrapper
24
+ from bead.config.active_learning import BinaryModelConfig
25
+ from bead.items.item import Item
26
+ from bead.items.item_template import ItemTemplate, TaskType
27
+
28
+ __all__ = ["BinaryModel"]
29
+
30
+
31
class BinaryModel(ActiveLearningModel):
    """Model for binary tasks (yes/no, true/false judgments).

    Uses true binary classification with a single output unit and sigmoid
    activation (logistic regression). This is more efficient than using
    2-class softmax, as we only need to output P(y=1) and compute
    P(y=0) = 1 - P(y=1).

    Parameters
    ----------
    config : BinaryModelConfig
        Configuration object containing all model parameters.

    Attributes
    ----------
    config : BinaryModelConfig
        Model configuration.
    tokenizer : AutoTokenizer
        Transformer tokenizer.
    encoder : AutoModel
        Transformer encoder model.
    classifier_head : nn.Sequential
        Classification head (fixed effects head) - outputs single logit.
    num_classes : int
        Number of output units (always 1 for binary classification).
    label_names : list[str] | None
        Label names (e.g., ["no", "yes"] sorted alphabetically).
    positive_class : str | None
        Which label corresponds to y=1 (second alphabetically).
    random_effects : RandomEffectsManager
        Manager for participant-level random effects (scalar biases).
    variance_history : list[VarianceComponents]
        Variance component estimates over training (for diagnostics).
    _is_fitted : bool
        Whether model has been trained.

    Examples
    --------
    >>> from uuid import uuid4
    >>> from bead.items.item import Item
    >>> from bead.config.active_learning import BinaryModelConfig
    >>> items = [
    ...     Item(
    ...         item_template_id=uuid4(),
    ...         rendered_elements={"text": f"Sentence {i}"}
    ...     )
    ...     for i in range(10)
    ... ]
    >>> labels = ["yes"] * 5 + ["no"] * 5
    >>> config = BinaryModelConfig(  # doctest: +SKIP
    ...     num_epochs=1, batch_size=2, device="cpu"
    ... )
    >>> model = BinaryModel(config=config)  # doctest: +SKIP
    >>> metrics = model.train(items, labels, participant_ids=None)  # doctest: +SKIP
    >>> predictions = model.predict(items[:3], participant_ids=None)  # doctest: +SKIP

    Notes
    -----
    This model uses BCEWithLogitsLoss instead of CrossEntropyLoss, and applies
    sigmoid activation to get probabilities. Random intercepts are scalar values
    (1-dimensional) that shift the logit for each participant.
    """
93
+
94
    def __init__(
        self,
        config: BinaryModelConfig | None = None,
    ) -> None:
        """Initialize binary model.

        Loads the pretrained tokenizer and encoder named in the config and
        moves the encoder to the configured device. The classifier head and
        random-effects manager are created lazily at training time (see
        ``_initialize_classifier`` / ``_initialize_random_effects``).

        Parameters
        ----------
        config : BinaryModelConfig | None
            Configuration object. If None, uses default configuration.
        """
        self.config = config or BinaryModelConfig()

        # Validate mixed_effects configuration
        super().__init__(self.config)

        # Pretrained backbone; downloads/caches weights via transformers.
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self.encoder = AutoModel.from_pretrained(self.config.model_name)

        self.num_classes: int = 1  # Single output unit for binary classification
        # Label bookkeeping is filled in by _prepare_training_data once labels
        # are observed; until then the model is untrained.
        self.label_names: list[str] | None = None
        self.positive_class: str | None = None  # Which label corresponds to 1
        self.classifier_head: nn.Sequential | None = None
        self._is_fitted = False

        # Initialize random effects manager (created on demand during training)
        self.random_effects: RandomEffectsManager | None = None
        self.variance_history: list[VarianceComponents] = []

        self.encoder.to(self.config.device)
124
+
125
+ @property
126
+ def supported_task_types(self) -> list[TaskType]:
127
+ """Get supported task types.
128
+
129
+ Returns
130
+ -------
131
+ list[TaskType]
132
+ List containing "binary".
133
+ """
134
+ return ["binary"]
135
+
136
+ def validate_item_compatibility(
137
+ self, item: Item, item_template: ItemTemplate
138
+ ) -> None:
139
+ """Validate item is compatible with binary model.
140
+
141
+ Parameters
142
+ ----------
143
+ item : Item
144
+ Item to validate.
145
+ item_template : ItemTemplate
146
+ Template the item was constructed from.
147
+
148
+ Raises
149
+ ------
150
+ ValueError
151
+ If task_type is not "binary".
152
+ """
153
+ if item_template.task_type != "binary":
154
+ raise ValueError(
155
+ f"Expected task_type 'binary', got '{item_template.task_type}'"
156
+ )
157
+
158
+ def _initialize_classifier(self) -> None:
159
+ """Initialize classification head for binary classification.
160
+
161
+ Outputs a single value (logit) for sigmoid activation.
162
+ """
163
+ hidden_size = self.encoder.config.hidden_size
164
+
165
+ # Single output unit for binary classification
166
+ if self.config.encoder_mode == "dual_encoder":
167
+ input_size = hidden_size * 2
168
+ else:
169
+ input_size = hidden_size
170
+
171
+ self.classifier_head = nn.Sequential(
172
+ nn.Linear(input_size, 256),
173
+ nn.ReLU(),
174
+ nn.Dropout(0.1),
175
+ nn.Linear(256, 1), # Single output unit
176
+ )
177
+ self.classifier_head.to(self.config.device)
178
+
179
    def _encode_texts(self, texts: list[str]) -> torch.Tensor:
        """Encode texts using single encoder.

        Tokenizes with padding/truncation to ``config.max_length``, runs the
        encoder, and returns the hidden state of the first token for each
        sequence (the [CLS] position for BERT-style encoders).

        Parameters
        ----------
        texts : list[str]
            Texts to encode.

        Returns
        -------
        torch.Tensor
            Encoded representations of shape (batch_size, hidden_size).
        """
        encodings = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.config.max_length,
            return_tensors="pt",
        )
        # Move every tensor in the tokenizer output to the model device.
        encodings = {k: v.to(self.config.device) for k, v in encodings.items()}

        outputs = self.encoder(**encodings)
        # First-token pooling: one vector per input sequence.
        return outputs.last_hidden_state[:, 0, :]
203
+
204
+ def _prepare_inputs(self, items: list[Item]) -> torch.Tensor:
205
+ """Prepare inputs for encoding.
206
+
207
+ For binary tasks, concatenates all rendered elements.
208
+
209
+ Parameters
210
+ ----------
211
+ items : list[Item]
212
+ Items to encode.
213
+
214
+ Returns
215
+ -------
216
+ torch.Tensor
217
+ Encoded representations.
218
+ """
219
+ texts = []
220
+ for item in items:
221
+ # Concatenate all rendered elements
222
+ all_text = " ".join(item.rendered_elements.values())
223
+ texts.append(all_text)
224
+ return self._encode_texts(texts)
225
+
226
+ def _validate_labels(self, labels: list[str]) -> None:
227
+ """Validate that all labels are valid.
228
+
229
+ Parameters
230
+ ----------
231
+ labels : list[str]
232
+ Labels to validate.
233
+
234
+ Raises
235
+ ------
236
+ ValueError
237
+ If any label is not in label_names.
238
+ """
239
+ if self.label_names is None:
240
+ raise ValueError("label_names not initialized")
241
+
242
+ valid_labels = set(self.label_names)
243
+ invalid = [label for label in labels if label not in valid_labels]
244
+ if invalid:
245
+ raise ValueError(
246
+ f"Invalid labels found: {set(invalid)}. "
247
+ f"Labels must be one of {valid_labels}."
248
+ )
249
+
250
+ def _prepare_training_data(
251
+ self,
252
+ items: list[Item],
253
+ labels: list[str],
254
+ participant_ids: list[str],
255
+ validation_items: list[Item] | None,
256
+ validation_labels: list[str] | None,
257
+ ) -> tuple[
258
+ list[Item], list[float], list[str], list[Item] | None, list[float] | None
259
+ ]:
260
+ """Prepare training data for binary model.
261
+
262
+ Parameters
263
+ ----------
264
+ items : list[Item]
265
+ Training items.
266
+ labels : list[str]
267
+ Training labels.
268
+ participant_ids : list[str]
269
+ Normalized participant IDs.
270
+ validation_items : list[Item] | None
271
+ Validation items.
272
+ validation_labels : list[str] | None
273
+ Validation labels.
274
+
275
+ Returns
276
+ -------
277
+ tuple[list[Item], list[float], list[str], list[Item] | None, list[float] | None]
278
+ Prepared items, numeric labels, participant_ids, validation_items,
279
+ numeric validation_labels.
280
+ """
281
+ # Initialize label names
282
+ unique_labels = sorted(set(labels))
283
+ if len(unique_labels) != 2:
284
+ raise ValueError(
285
+ f"Binary classification requires exactly 2 classes, "
286
+ f"got {len(unique_labels)}: {unique_labels}"
287
+ )
288
+ self.label_names = unique_labels
289
+ # Positive class is the second one alphabetically (index 1)
290
+ self.positive_class = unique_labels[1]
291
+
292
+ self._validate_labels(labels)
293
+ self._initialize_classifier()
294
+
295
+ # Convert labels to binary (0/1) floats for HuggingFace Trainer
296
+ # Positive class (second alphabetically) = 1, negative = 0
297
+ y_numeric = [1.0 if label == self.positive_class else 0.0 for label in labels]
298
+
299
+ # Convert validation labels if provided
300
+ val_y_numeric = None
301
+ if validation_items is not None and validation_labels is not None:
302
+ self._validate_labels(validation_labels)
303
+ if len(validation_items) != len(validation_labels):
304
+ raise ValueError(
305
+ f"Number of validation items ({len(validation_items)}) "
306
+ f"must match number of validation labels ({len(validation_labels)})"
307
+ )
308
+ val_y_numeric = [
309
+ 1.0 if label == self.positive_class else 0.0
310
+ for label in validation_labels
311
+ ]
312
+
313
+ return items, y_numeric, participant_ids, validation_items, val_y_numeric
314
+
315
+ def _initialize_random_effects(self, n_classes: int) -> None:
316
+ """Initialize random effects manager.
317
+
318
+ Parameters
319
+ ----------
320
+ n_classes : int
321
+ Number of classes (1 for binary).
322
+ """
323
+ self.random_effects = RandomEffectsManager(
324
+ self.config.mixed_effects, n_classes=n_classes
325
+ )
326
+
327
    def _do_training(
        self,
        items: list[Item],
        labels_numeric: list[float],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels_numeric: list[float] | None,
    ) -> dict[str, float]:
        """Perform binary model training.

        Dispatches to the HuggingFace Trainer (``fixed`` /
        ``random_intercepts`` modes) or to the custom loop
        (``random_slopes``), then recomputes validation accuracy with this
        model's own prediction path.

        Parameters
        ----------
        items : list[Item]
            Training items.
        labels_numeric : list[float]
            Numeric labels (0.0 or 1.0).
        participant_ids : list[str]
            Participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels_numeric : list[float] | None
            Numeric validation labels.

        Returns
        -------
        dict[str, float]
            Training metrics.
        """
        # Convert validation_labels_numeric back to string labels for validation metrics
        # (label_names[0] is the negative class; positive_class is label_names[1]).
        validation_labels = None
        if validation_items is not None and validation_labels_numeric is not None:
            validation_labels = [
                self.positive_class if label == 1.0 else self.label_names[0]
                for label in validation_labels_numeric
            ]

        # Use HuggingFace Trainer for fixed and random_intercepts modes
        # random_slopes requires custom loop due to per-participant heads
        use_huggingface_trainer = self.config.mixed_effects.mode in (
            "fixed",
            "random_intercepts",
        )

        if use_huggingface_trainer:
            metrics = self._train_with_huggingface_trainer(
                items,
                labels_numeric,
                participant_ids,
                validation_items,
                validation_labels,
            )
        else:
            # Use custom training loop for random_slopes
            metrics = self._train_with_custom_loop(
                items,
                labels_numeric,
                participant_ids,
                validation_items,
                validation_labels,
            )

        # Add validation accuracy if validation data provided.
        # NOTE(review): this overwrites any "val_accuracy" already produced by
        # the trainer above with a value computed through _do_predict — the two
        # can differ because prediction applies the model's random-effects
        # handling for unseen participants. Presumably intentional; confirm.
        if validation_items is not None and validation_labels is not None:
            # Validation with placeholder participant_ids for mixed effects
            # Use _do_predict directly since we're in training
            if self.config.mixed_effects.mode == "fixed":
                val_participant_ids = ["_fixed_"] * len(validation_items)
            else:
                val_participant_ids = ["_validation_"] * len(validation_items)
            val_predictions = self._do_predict(validation_items, val_participant_ids)
            val_pred_labels = [p.predicted_class for p in val_predictions]
            val_acc = sum(
                pred == true
                for pred, true in zip(val_pred_labels, validation_labels, strict=True)
            ) / len(validation_labels)
            metrics["val_accuracy"] = val_acc

        return metrics
405
+
406
+ def _train_with_huggingface_trainer(
407
+ self,
408
+ items: list[Item],
409
+ y_numeric: list[float],
410
+ participant_ids: list[str],
411
+ validation_items: list[Item] | None,
412
+ validation_labels: list[str] | None,
413
+ ) -> dict[str, float]:
414
+ """Train using HuggingFace Trainer with mixed effects support.
415
+
416
+ Parameters
417
+ ----------
418
+ items : list[Item]
419
+ Training items.
420
+ y_numeric : list[float]
421
+ Numeric labels (0.0 or 1.0).
422
+ participant_ids : list[str]
423
+ Participant IDs.
424
+ validation_items : list[Item] | None
425
+ Validation items.
426
+ validation_labels : list[str] | None
427
+ Validation labels.
428
+
429
+ Returns
430
+ -------
431
+ dict[str, float]
432
+ Training metrics.
433
+ """
434
+ # Convert items to HuggingFace Dataset
435
+ train_dataset = items_to_dataset(
436
+ items=items,
437
+ labels=y_numeric,
438
+ participant_ids=participant_ids,
439
+ tokenizer=self.tokenizer,
440
+ max_length=self.config.max_length,
441
+ )
442
+
443
+ # Create validation dataset if provided
444
+ eval_dataset = None
445
+ if validation_items is not None and validation_labels is not None:
446
+ val_y_numeric = [
447
+ 1.0 if label == self.positive_class else 0.0
448
+ for label in validation_labels
449
+ ]
450
+ eval_dataset = items_to_dataset(
451
+ items=validation_items,
452
+ labels=val_y_numeric,
453
+ participant_ids=["_validation_"] * len(validation_items),
454
+ tokenizer=self.tokenizer,
455
+ max_length=self.config.max_length,
456
+ )
457
+
458
+ # Create wrapper model for Trainer
459
+ wrapped_model = EncoderClassifierWrapper(
460
+ encoder=self.encoder, classifier_head=self.classifier_head
461
+ )
462
+
463
+ # Create data collator
464
+ data_collator = MixedEffectsDataCollator(tokenizer=self.tokenizer)
465
+
466
+ # Create training arguments with checkpointing
467
+ with tempfile.TemporaryDirectory() as tmpdir:
468
+ checkpoint_dir = Path(tmpdir) / "checkpoints"
469
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
470
+
471
+ training_args = TrainingArguments(
472
+ output_dir=str(checkpoint_dir),
473
+ num_train_epochs=self.config.num_epochs,
474
+ per_device_train_batch_size=self.config.batch_size,
475
+ per_device_eval_batch_size=self.config.batch_size,
476
+ learning_rate=self.config.learning_rate,
477
+ logging_steps=10,
478
+ eval_strategy="epoch" if eval_dataset is not None else "no",
479
+ save_strategy="epoch", # Save checkpoints every epoch
480
+ save_total_limit=1, # Keep only the latest checkpoint
481
+ load_best_model_at_end=False, # Don't auto-load best
482
+ report_to="none", # Disable wandb/tensorboard
483
+ remove_unused_columns=False, # Keep participant_id
484
+ use_cpu=self.config.device == "cpu", # Explicitly use CPU if specified
485
+ )
486
+
487
+ # Import here to avoid circular import
488
+ from bead.active_learning.trainers.mixed_effects import ( # noqa: PLC0415
489
+ MixedEffectsTrainer,
490
+ )
491
+
492
+ # Create trainer
493
+ trainer = MixedEffectsTrainer(
494
+ model=wrapped_model,
495
+ args=training_args,
496
+ train_dataset=train_dataset,
497
+ eval_dataset=eval_dataset,
498
+ data_collator=data_collator,
499
+ tokenizer=self.tokenizer,
500
+ random_effects_manager=self.random_effects,
501
+ compute_metrics=compute_binary_metrics,
502
+ )
503
+
504
+ # Train (checkpoints are saved automatically by Trainer)
505
+ train_result = trainer.train()
506
+
507
+ # Get training metrics using evaluate (Trainer computes metrics during eval)
508
+ train_metrics = trainer.evaluate(eval_dataset=train_dataset)
509
+ # Trainer prefixes eval metrics with "eval_"
510
+ metrics: dict[str, float] = {
511
+ "train_loss": float(train_result.training_loss),
512
+ "train_accuracy": train_metrics.get("eval_accuracy", 0.0),
513
+ "train_precision": train_metrics.get("eval_precision", 0.0),
514
+ "train_recall": train_metrics.get("eval_recall", 0.0),
515
+ "train_f1": train_metrics.get("eval_f1", 0.0),
516
+ }
517
+
518
+ # Get validation metrics if eval_dataset was provided
519
+ if eval_dataset is not None:
520
+ val_metrics = trainer.evaluate(eval_dataset=eval_dataset)
521
+ metrics.update(
522
+ {
523
+ "val_accuracy": val_metrics.get("eval_accuracy", 0.0),
524
+ "val_precision": val_metrics.get("eval_precision", 0.0),
525
+ "val_recall": val_metrics.get("eval_recall", 0.0),
526
+ "val_f1": val_metrics.get("eval_f1", 0.0),
527
+ }
528
+ )
529
+
530
+ # Estimate variance components
531
+ if self.config.mixed_effects.estimate_variance_components:
532
+ var_comps = self.random_effects.estimate_variance_components()
533
+ if var_comps:
534
+ var_comp = var_comps.get("mu") or var_comps.get("slopes")
535
+ if var_comp:
536
+ self.variance_history.append(var_comp)
537
+ metrics["participant_variance"] = var_comp.variance
538
+ metrics["n_participants"] = var_comp.n_groups
539
+
540
+ # Validation metrics (already computed by Trainer if eval_dataset provided)
541
+ if eval_dataset is not None:
542
+ val_metrics = trainer.evaluate(eval_dataset=eval_dataset)
543
+ metrics.update(
544
+ {
545
+ "val_accuracy": val_metrics.get("eval_accuracy", 0.0),
546
+ "val_precision": val_metrics.get("eval_precision", 0.0),
547
+ "val_recall": val_metrics.get("eval_recall", 0.0),
548
+ "val_f1": val_metrics.get("eval_f1", 0.0),
549
+ }
550
+ )
551
+
552
+ return metrics
553
+
554
    def _train_with_custom_loop(
        self,
        items: list[Item],
        y_numeric: list[float],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels: list[str] | None,
    ) -> dict[str, float]:
        """Train using custom training loop (for random_slopes mode).

        Optimizes the encoder, the fixed classifier head, and every
        per-participant slope head jointly with AdamW + BCEWithLogitsLoss,
        adding the random-effects prior as a regularizer.

        Parameters
        ----------
        items : list[Item]
            Training items.
        y_numeric : list[float]
            Numeric labels (0.0 or 1.0).
        participant_ids : list[str]
            Participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels : list[str] | None
            Validation labels.

        Returns
        -------
        dict[str, float]
            Training metrics. ``train_accuracy``/``train_loss`` reflect the
            final epoch only (earlier epochs' values are overwritten).
        """
        # Convert to tensor
        y = torch.tensor(y_numeric, dtype=torch.float, device=self.config.device)

        # Build optimizer parameters
        params_to_optimize = list(self.encoder.parameters()) + list(
            self.classifier_head.parameters()
        )

        # Add random effects parameters (for random_slopes); only heads that
        # already exist at this point are registered with the optimizer.
        if self.config.mixed_effects.mode == "random_slopes":
            for head in self.random_effects.slopes.values():
                params_to_optimize.extend(head.parameters())

        optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
        criterion = nn.BCEWithLogitsLoss()

        self.encoder.train()
        self.classifier_head.train()

        for _epoch in range(self.config.num_epochs):
            # Ceiling division: the last batch may be smaller than batch_size.
            n_batches = (
                len(items) + self.config.batch_size - 1
            ) // self.config.batch_size
            epoch_loss = 0.0
            epoch_correct = 0

            for i in range(n_batches):
                start_idx = i * self.config.batch_size
                end_idx = min(start_idx + self.config.batch_size, len(items))

                batch_items = items[start_idx:end_idx]
                batch_labels = y[start_idx:end_idx]
                batch_participant_ids = participant_ids[start_idx:end_idx]

                embeddings = self._prepare_inputs(batch_items)

                # Random slopes: per-participant head, created on first sight.
                logits_list = []
                for j, pid in enumerate(batch_participant_ids):
                    participant_head = self.random_effects.get_slopes(
                        pid,
                        fixed_head=self.classifier_head,
                        create_if_missing=True,
                    )
                    # One-row slice keeps the batch dimension; squeeze yields
                    # a scalar logit per example.
                    logits_j = participant_head(embeddings[j : j + 1]).squeeze()
                    logits_list.append(logits_j)
                logits = torch.stack(logits_list)

                # Data loss + prior regularization
                loss_bce = criterion(logits, batch_labels)
                loss_prior = self.random_effects.compute_prior_loss()
                loss = loss_bce + loss_prior

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                # Threshold the sigmoid at 0.5 for hard predictions.
                predictions = (torch.sigmoid(logits) > 0.5).float()
                epoch_correct += (predictions == batch_labels).sum().item()

            epoch_acc = epoch_correct / len(items)
            epoch_loss = epoch_loss / n_batches

        metrics: dict[str, float] = {
            "train_accuracy": epoch_acc,
            "train_loss": epoch_loss,
        }

        # Estimate variance components
        if self.config.mixed_effects.estimate_variance_components:
            var_comps = self.random_effects.estimate_variance_components()
            if var_comps:
                var_comp = var_comps.get("mu") or var_comps.get("slopes")
                if var_comp:
                    self.variance_history.append(var_comp)
                    metrics["participant_variance"] = var_comp.variance
                    metrics["n_participants"] = var_comp.n_groups

        # Validation
        # NOTE(review): predict() is called while encoder/classifier are still
        # in train() mode (dropout active) — confirm whether predict switches
        # to eval mode internally.
        if validation_items is not None and validation_labels is not None:
            val_predictions = self.predict(
                validation_items,
                participant_ids=["_validation_"] * len(validation_items),
            )
            val_pred_labels = [p.predicted_class for p in val_predictions]
            val_acc = sum(
                pred == true
                for pred, true in zip(val_pred_labels, validation_labels, strict=True)
            ) / len(validation_labels)
            metrics["val_accuracy"] = val_acc

        return metrics
675
+
676
+ def _do_predict(
677
+ self, items: list[Item], participant_ids: list[str]
678
+ ) -> list[ModelPrediction]:
679
+ """Perform binary model prediction.
680
+
681
+ Parameters
682
+ ----------
683
+ items : list[Item]
684
+ Items to predict.
685
+ participant_ids : list[str]
686
+ Normalized participant IDs.
687
+
688
+ Returns
689
+ -------
690
+ list[ModelPrediction]
691
+ Predictions.
692
+ """
693
+ self.encoder.eval()
694
+ self.classifier_head.eval()
695
+
696
+ with torch.no_grad():
697
+ embeddings = self._prepare_inputs(items)
698
+
699
+ # Forward pass depends on mixed effects mode
700
+ if self.config.mixed_effects.mode == "fixed":
701
+ logits = self.classifier_head(embeddings).squeeze(1) # (n_items,)
702
+
703
+ elif self.config.mixed_effects.mode == "random_intercepts":
704
+ logits = self.classifier_head(embeddings).squeeze(1) # (n_items,)
705
+ for i, pid in enumerate(participant_ids):
706
+ # Unknown participants: use prior mean (zero bias)
707
+ bias = self.random_effects.get_intercepts(
708
+ pid,
709
+ n_classes=self.num_classes,
710
+ param_name="mu",
711
+ create_if_missing=False,
712
+ )
713
+ logits[i] = logits[i] + bias.item()
714
+
715
+ elif self.config.mixed_effects.mode == "random_slopes":
716
+ logits_list = []
717
+ for i, pid in enumerate(participant_ids):
718
+ # Unknown participants: use fixed head
719
+ participant_head = self.random_effects.get_slopes(
720
+ pid, fixed_head=self.classifier_head, create_if_missing=False
721
+ )
722
+ logits_i = participant_head(embeddings[i : i + 1]).squeeze()
723
+ logits_list.append(logits_i)
724
+ logits = torch.stack(logits_list)
725
+
726
+ # Compute probabilities using sigmoid
727
+ proba_positive = torch.sigmoid(logits).cpu().numpy() # P(y=1)
728
+ pred_is_positive = proba_positive > 0.5
729
+
730
+ predictions = []
731
+ for i, item in enumerate(items):
732
+ # Determine predicted class
733
+ if pred_is_positive[i]:
734
+ pred_label = self.positive_class
735
+ else:
736
+ pred_label = self.label_names[0]
737
+
738
+ # Build probability dict: {negative_class: p0, positive_class: p1}
739
+ p1 = float(proba_positive[i])
740
+ p0 = 1.0 - p1
741
+ prob_dict = {
742
+ self.label_names[0]: p0, # Negative class (first alphabetically)
743
+ self.positive_class: p1, # Positive class (second alphabetically)
744
+ }
745
+
746
+ predictions.append(
747
+ ModelPrediction(
748
+ item_id=str(item.id),
749
+ probabilities=prob_dict,
750
+ predicted_class=pred_label,
751
+ confidence=max(p0, p1),
752
+ )
753
+ )
754
+
755
+ return predictions
756
+
757
+ def _do_predict_proba(
758
+ self, items: list[Item], participant_ids: list[str]
759
+ ) -> np.ndarray:
760
+ """Perform binary model probability prediction.
761
+
762
+ Parameters
763
+ ----------
764
+ items : list[Item]
765
+ Items to predict.
766
+ participant_ids : list[str]
767
+ Normalized participant IDs.
768
+
769
+ Returns
770
+ -------
771
+ np.ndarray
772
+ Probability array of shape (n_items, 2).
773
+ """
774
+ self.encoder.eval()
775
+ self.classifier_head.eval()
776
+
777
+ with torch.no_grad():
778
+ embeddings = self._prepare_inputs(items)
779
+
780
+ # Forward pass depends on mixed effects mode
781
+ if self.config.mixed_effects.mode == "fixed":
782
+ logits = self.classifier_head(embeddings).squeeze(1)
783
+
784
+ elif self.config.mixed_effects.mode == "random_intercepts":
785
+ logits = self.classifier_head(embeddings).squeeze(1)
786
+ for i, pid in enumerate(participant_ids):
787
+ bias = self.random_effects.get_intercepts(
788
+ pid,
789
+ n_classes=self.num_classes,
790
+ param_name="mu",
791
+ create_if_missing=False,
792
+ )
793
+ logits[i] = logits[i] + bias.item()
794
+
795
+ elif self.config.mixed_effects.mode == "random_slopes":
796
+ logits_list = []
797
+ for i, pid in enumerate(participant_ids):
798
+ participant_head = self.random_effects.get_slopes(
799
+ pid, fixed_head=self.classifier_head, create_if_missing=False
800
+ )
801
+ logits_i = participant_head(embeddings[i : i + 1]).squeeze()
802
+ logits_list.append(logits_i)
803
+ logits = torch.stack(logits_list)
804
+
805
+ # Compute probabilities using sigmoid
806
+ proba_positive = torch.sigmoid(logits).cpu().numpy() # P(y=1)
807
+
808
+ # Return (n_items, 2) array: [P(negative), P(positive)]
809
+ proba = np.stack([1.0 - proba_positive, proba_positive], axis=1)
810
+
811
+ return proba
812
+
813
+ def _get_save_state(self) -> dict[str, object]:
814
+ """Get model-specific state to save.
815
+
816
+ Returns
817
+ -------
818
+ dict[str, object]
819
+ State dictionary.
820
+ """
821
+ return {
822
+ "num_classes": self.num_classes,
823
+ "label_names": self.label_names,
824
+ "positive_class": self.positive_class,
825
+ }
826
+
827
+ def _save_model_components(self, save_path: Path) -> None:
828
+ """Save model-specific components.
829
+
830
+ Parameters
831
+ ----------
832
+ save_path : Path
833
+ Directory to save to.
834
+ """
835
+ self.encoder.save_pretrained(save_path / "encoder")
836
+ self.tokenizer.save_pretrained(save_path / "encoder")
837
+
838
+ torch.save(
839
+ self.classifier_head.state_dict(),
840
+ save_path / "classifier_head.pt",
841
+ )
842
+
843
+ def _restore_training_state(self, config_dict: dict[str, object]) -> None:
844
+ """Restore model-specific training state.
845
+
846
+ Parameters
847
+ ----------
848
+ config_dict : dict[str, object]
849
+ Configuration dictionary with training state.
850
+ """
851
+ self.num_classes = config_dict.pop("num_classes")
852
+ self.label_names = config_dict.pop("label_names")
853
+ self.positive_class = config_dict.pop("positive_class")
854
+
855
+ def _load_model_components(self, load_path: Path) -> None:
856
+ """Load model-specific components.
857
+
858
+ Parameters
859
+ ----------
860
+ load_path : Path
861
+ Directory to load from.
862
+ """
863
+ # Load config.json to reconstruct config
864
+ with open(load_path / "config.json") as f:
865
+ import json # noqa: PLC0415
866
+
867
+ config_dict = json.load(f)
868
+
869
+ # Reconstruct MixedEffectsConfig if needed
870
+ if "mixed_effects" in config_dict and isinstance(
871
+ config_dict["mixed_effects"], dict
872
+ ):
873
+ from bead.active_learning.config import MixedEffectsConfig # noqa: PLC0415
874
+
875
+ config_dict["mixed_effects"] = MixedEffectsConfig(
876
+ **config_dict["mixed_effects"]
877
+ )
878
+
879
+ self.config = BinaryModelConfig(**config_dict)
880
+
881
+ self.encoder = AutoModel.from_pretrained(load_path / "encoder")
882
+ self.tokenizer = AutoTokenizer.from_pretrained(load_path / "encoder")
883
+
884
+ self._initialize_classifier()
885
+ self.classifier_head.load_state_dict(
886
+ torch.load(
887
+ load_path / "classifier_head.pt", map_location=self.config.device
888
+ )
889
+ )
890
+ self.classifier_head.to(self.config.device)
891
+
892
+ def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
893
+ """Get fixed head for random effects loading.
894
+
895
+ Returns
896
+ -------
897
+ nn.Module | None
898
+ Fixed head module.
899
+ """
900
+ return self.classifier_head
901
+
902
+ def _get_n_classes_for_random_effects(self) -> int:
903
+ """Get number of classes for random effects initialization.
904
+
905
+ Returns
906
+ -------
907
+ int
908
+ Number of classes (1 for binary).
909
+ """
910
+ return self.num_classes