bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,943 @@
+ """Model for categorical tasks (unordered N-class classification)."""
+
+ from __future__ import annotations
+
+ import tempfile
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel, AutoTokenizer, TrainingArguments
+
+ from bead.active_learning.config import VarianceComponents
+ from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
+ from bead.active_learning.models.random_effects import RandomEffectsManager
+ from bead.active_learning.trainers.data_collator import MixedEffectsDataCollator
+ from bead.active_learning.trainers.dataset_utils import items_to_dataset
+ from bead.active_learning.trainers.metrics import compute_multiclass_metrics
+ from bead.active_learning.trainers.model_wrapper import EncoderClassifierWrapper
+ from bead.config.active_learning import CategoricalModelConfig
+ from bead.items.item import Item
+ from bead.items.item_template import ItemTemplate, TaskType
+
+ __all__ = ["CategoricalModel"]
+
+
+ class CategoricalModel(ActiveLearningModel):
+     """Model for categorical tasks with N unordered categories.
+
+     Supports N-class classification (N ≥ 2) using any HuggingFace transformer
+     model. Provides two encoding strategies: single encoder (concatenate
+     categories) or dual encoder (separate embeddings).
+
+     Parameters
+     ----------
+     config : CategoricalModelConfig
+         Configuration object containing all model parameters.
+
+     Attributes
+     ----------
+     config : CategoricalModelConfig
+         Model configuration.
+     tokenizer : AutoTokenizer
+         Transformer tokenizer.
+     encoder : AutoModel
+         Transformer encoder model.
+     classifier_head : nn.Sequential
+         Classification head (fixed effects head).
+     num_classes : int | None
+         Number of classes (inferred from training data).
+     category_names : list[str] | None
+         Category names (e.g., ["entailment", "neutral", "contradiction"]).
+     random_effects : RandomEffectsManager
+         Manager for participant-level random effects.
+     variance_history : list[VarianceComponents]
+         Variance component estimates over training (for diagnostics).
+     _is_fitted : bool
+         Whether model has been trained.
+
+     Examples
+     --------
+     >>> from uuid import uuid4
+     >>> from bead.items.item import Item
+     >>> from bead.config.active_learning import CategoricalModelConfig
+     >>> items = [
+     ...     Item(
+     ...         item_template_id=uuid4(),
+     ...         rendered_elements={"premise": "sent A", "hypothesis": "sent B"}
+     ...     )
+     ...     for _ in range(10)
+     ... ]
+     >>> labels = ["entailment"] * 5 + ["contradiction"] * 5
+     >>> config = CategoricalModelConfig(  # doctest: +SKIP
+     ...     num_epochs=1, batch_size=2, device="cpu"
+     ... )
+     >>> model = CategoricalModel(config=config)  # doctest: +SKIP
+     >>> metrics = model.train(items, labels, participant_ids=None)  # doctest: +SKIP
+     >>> predictions = model.predict(items[:3], participant_ids=None)  # doctest: +SKIP
+     """
+
+     def __init__(
+         self,
+         config: CategoricalModelConfig | None = None,
+     ) -> None:
+         """Initialize categorical model.
+
+         Parameters
+         ----------
+         config : CategoricalModelConfig | None
+             Configuration object. If None, uses default configuration.
+         """
+         self.config = config or CategoricalModelConfig()
+
+         # Validate mixed_effects configuration
+         super().__init__(self.config)
+
+         self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+         self.encoder = AutoModel.from_pretrained(self.config.model_name)
+
+         self.num_classes: int | None = None
+         self.category_names: list[str] | None = None
+         self.classifier_head: nn.Sequential | None = None
+         self._is_fitted = False
+
+         # Initialize random effects manager
+         self.random_effects: RandomEffectsManager | None = None
+         self.variance_history: list[VarianceComponents] = []
+
+         self.encoder.to(self.config.device)
+
+     @property
+     def supported_task_types(self) -> list[TaskType]:
+         """Get supported task types.
+
+         Returns
+         -------
+         list[TaskType]
+             List containing "categorical".
+         """
+         return ["categorical"]
+
+     def validate_item_compatibility(
+         self, item: Item, item_template: ItemTemplate
+     ) -> None:
+         """Validate item is compatible with categorical model.
+
+         Parameters
+         ----------
+         item : Item
+             Item to validate.
+         item_template : ItemTemplate
+             Template the item was constructed from.
+
+         Raises
+         ------
+         ValueError
+             If task_type is not "categorical".
+         """
+         if item_template.task_type != "categorical":
+             raise ValueError(
+                 f"Expected task_type 'categorical', got '{item_template.task_type}'"
+             )
+
+     def _initialize_classifier(self, num_classes: int) -> None:
+         """Initialize classification head for given number of classes.
+
+         Parameters
+         ----------
+         num_classes : int
+             Number of output classes.
+         """
+         hidden_size = self.encoder.config.hidden_size
+
+         if self.config.encoder_mode == "dual_encoder":
+             input_size = hidden_size * num_classes
+         else:
+             input_size = hidden_size
+
+         self.classifier_head = nn.Sequential(
+             nn.Linear(input_size, 256),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(256, num_classes),
+         )
+         self.classifier_head.to(self.config.device)
+
+     def _encode_single(self, texts: list[str]) -> torch.Tensor:
+         """Encode texts using single encoder strategy.
+
+         Concatenates all category texts with [SEP] tokens and encodes once.
+
+         Parameters
+         ----------
+         texts : list[str]
+             List of concatenated category texts for each item.
+
+         Returns
+         -------
+         torch.Tensor
+             Encoded representations of shape (batch_size, hidden_size).
+         """
+         encodings = self.tokenizer(
+             texts,
+             padding=True,
+             truncation=True,
+             max_length=self.config.max_length,
+             return_tensors="pt",
+         )
+         encodings = {k: v.to(self.config.device) for k, v in encodings.items()}
+
+         outputs = self.encoder(**encodings)
+         return outputs.last_hidden_state[:, 0, :]
+
+     def _encode_dual(self, categories_per_item: list[list[str]]) -> torch.Tensor:
+         """Encode texts using dual encoder strategy.
+
+         Encodes each category separately and concatenates embeddings.
+
+         Parameters
+         ----------
+         categories_per_item : list[list[str]]
+             List of category lists. Each inner list contains category texts
+             for one item.
+
+         Returns
+         -------
+         torch.Tensor
+             Concatenated encodings of shape (batch_size, hidden_size * num_categories).
+         """
+         all_embeddings = []
+
+         for categories in categories_per_item:
+             category_embeddings = []
+             for category_text in categories:
+                 encodings = self.tokenizer(
+                     [category_text],
+                     padding=True,
+                     truncation=True,
+                     max_length=self.config.max_length,
+                     return_tensors="pt",
+                 )
+                 encodings = {k: v.to(self.config.device) for k, v in encodings.items()}
+
+                 outputs = self.encoder(**encodings)
+                 cls_embedding = outputs.last_hidden_state[0, 0, :]
+                 category_embeddings.append(cls_embedding)
+
+             concatenated = torch.cat(category_embeddings, dim=0)
+             all_embeddings.append(concatenated)
+
+         return torch.stack(all_embeddings)
+
+     def _prepare_inputs(self, items: list[Item]) -> torch.Tensor:
+         """Prepare inputs for encoding based on encoder mode.
+
+         For categorical tasks, concatenates all rendered elements.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to encode.
+
+         Returns
+         -------
+         torch.Tensor
+             Encoded representations.
+         """
+         if self.category_names is None:
+             raise ValueError("Model not initialized. Call train() first.")
+
+         if self.config.encoder_mode == "single_encoder":
+             texts = []
+             for item in items:
+                 # Concatenate all rendered elements
+                 all_text = " ".join(item.rendered_elements.values())
+                 texts.append(all_text)
+             return self._encode_single(texts)
+         else:
+             categories_per_item = []
+             for item in items:
+                 category_texts = list(item.rendered_elements.values())
+                 categories_per_item.append(category_texts)
+             return self._encode_dual(categories_per_item)
+
+     def _validate_labels(self, labels: list[str]) -> None:
+         """Validate that all labels are valid category names.
+
+         Parameters
+         ----------
+         labels : list[str]
+             Labels to validate.
+
+         Raises
+         ------
+         ValueError
+             If any label is not in category_names.
+         """
+         if self.category_names is None:
+             raise ValueError("category_names not initialized")
+
+         valid_labels = set(self.category_names)
+         invalid = [label for label in labels if label not in valid_labels]
+         if invalid:
+             raise ValueError(
+                 f"Invalid labels found: {set(invalid)}. "
+                 f"Labels must be one of {valid_labels}."
+             )
+
+     def _prepare_training_data(
+         self,
+         items: list[Item],
+         labels: list[str],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> tuple[list[Item], list[int], list[str], list[Item] | None, list[int] | None]:
+         """Prepare training data for categorical model.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels : list[str]
+             Training labels.
+         participant_ids : list[str]
+             Normalized participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+
+         Returns
+         -------
+         tuple[list[Item], list[int], list[str], list[Item] | None, list[int] | None]
+             Prepared items, numeric labels, participant_ids, validation_items,
+             numeric validation_labels.
+         """
+         unique_labels = sorted(set(labels))
+         self.num_classes = len(unique_labels)
+         self.category_names = unique_labels
+
+         self._validate_labels(labels)
+         self._initialize_classifier(self.num_classes)
+
+         label_to_idx = {label: idx for idx, label in enumerate(self.category_names)}
+         y_numeric = [label_to_idx[label] for label in labels]
+
+         # Convert validation labels if provided
+         val_y_numeric = None
+         if validation_items is not None and validation_labels is not None:
+             self._validate_labels(validation_labels)
+             if len(validation_items) != len(validation_labels):
+                 raise ValueError(
+                     f"Number of validation items ({len(validation_items)}) "
+                     f"must match number of validation labels ({len(validation_labels)})"
+                 )
+             val_y_numeric = [label_to_idx[label] for label in validation_labels]
+
+         return items, y_numeric, participant_ids, validation_items, val_y_numeric
+
+     def _initialize_random_effects(self, n_classes: int) -> None:
+         """Initialize random effects manager.
+
+         Parameters
+         ----------
+         n_classes : int
+             Number of classes.
+         """
+         self.random_effects = RandomEffectsManager(
+             self.config.mixed_effects, n_classes=n_classes
+         )
+
+     def _do_training(
+         self,
+         items: list[Item],
+         labels_numeric: list[int],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels_numeric: list[int] | None,
+     ) -> dict[str, float]:
+         """Perform categorical model training.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels_numeric : list[int]
+             Numeric labels (class indices).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels_numeric : list[int] | None
+             Numeric validation labels.
+
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         # Convert validation_labels_numeric back to string labels for validation metrics
+         validation_labels = None
+         if validation_items is not None and validation_labels_numeric is not None:
+             validation_labels = [
+                 self.category_names[label_idx]
+                 for label_idx in validation_labels_numeric
+             ]
+
+         # Use HuggingFace Trainer for fixed and random_intercepts modes
+         if self.config.mixed_effects.mode in ("fixed", "random_intercepts"):
+             metrics = self._train_with_huggingface_trainer(
+                 items=items,
+                 y_numeric=labels_numeric,
+                 participant_ids=participant_ids,
+                 validation_items=validation_items,
+                 validation_labels=validation_labels,
+             )
+         else:
+             # Use custom loop for random_slopes mode
+             metrics = self._train_with_custom_loop(
+                 items=items,
+                 y_numeric=labels_numeric,
+                 participant_ids=participant_ids,
+                 validation_items=validation_items,
+                 validation_labels=validation_labels,
+             )
+
+         # Add validation accuracy if validation data provided and not already computed
+         if (
+             validation_items is not None
+             and validation_labels is not None
+             and "val_accuracy" not in metrics
+         ):
+             # Validation with placeholder participant_ids for mixed effects
+             if self.config.mixed_effects.mode == "fixed":
+                 val_participant_ids = ["_fixed_"] * len(validation_items)
+             else:
+                 val_participant_ids = ["_validation_"] * len(validation_items)
+             val_predictions = self._do_predict(validation_items, val_participant_ids)
+             val_pred_labels = [p.predicted_class for p in val_predictions]
+             val_acc = sum(
+                 pred == true
+                 for pred, true in zip(val_pred_labels, validation_labels, strict=True)
+             ) / len(validation_labels)
+             metrics["val_accuracy"] = val_acc
+
+         return metrics
+
+     def _train_with_huggingface_trainer(
+         self,
+         items: list[Item],
+         y_numeric: list[int],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> dict[str, float]:
+         """Train using HuggingFace Trainer with mixed effects support.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         y_numeric : list[int]
+             Numeric labels (class indices).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         # Convert items to HuggingFace Dataset
+         train_dataset = items_to_dataset(
+             items=items,
+             labels=y_numeric,
+             participant_ids=participant_ids,
+             tokenizer=self.tokenizer,
+             max_length=self.config.max_length,
+         )
+
+         # Create validation dataset if provided
+         eval_dataset = None
+         if validation_items is not None and validation_labels is not None:
+             label_to_idx = {label: idx for idx, label in enumerate(self.category_names)}
+             val_y_numeric = [label_to_idx[label] for label in validation_labels]
+             val_participant_ids = (
+                 ["_validation_"] * len(validation_items)
+                 if self.config.mixed_effects.mode != "fixed"
+                 else ["_fixed_"] * len(validation_items)
+             )
+             eval_dataset = items_to_dataset(
+                 items=validation_items,
+                 labels=val_y_numeric,
+                 participant_ids=val_participant_ids,
+                 tokenizer=self.tokenizer,
+                 max_length=self.config.max_length,
+             )
+
+         # Create wrapper model for Trainer
+         wrapped_model = EncoderClassifierWrapper(
+             encoder=self.encoder, classifier_head=self.classifier_head
+         )
+
+         # Create data collator
+         data_collator = MixedEffectsDataCollator(tokenizer=self.tokenizer)
+
+         # Create metrics computation function
+         def compute_metrics_fn(eval_pred: object) -> dict[str, float]:
+             return compute_multiclass_metrics(eval_pred, num_labels=self.num_classes)
+
+         # Create training arguments with checkpointing
+         with tempfile.TemporaryDirectory() as tmpdir:
+             checkpoint_dir = Path(tmpdir) / "checkpoints"
+             checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+             training_args = TrainingArguments(
+                 output_dir=str(checkpoint_dir),
+                 num_train_epochs=self.config.num_epochs,
+                 per_device_train_batch_size=self.config.batch_size,
+                 per_device_eval_batch_size=self.config.batch_size,
+                 learning_rate=self.config.learning_rate,
+                 logging_steps=10,
+                 eval_strategy="epoch" if eval_dataset is not None else "no",
+                 save_strategy="epoch",
+                 save_total_limit=1,
+                 load_best_model_at_end=False,
+                 report_to="none",
+                 remove_unused_columns=False,
+                 use_cpu=self.config.device == "cpu",
+             )
+
+             # Import here to avoid circular import
+             from bead.active_learning.trainers.mixed_effects import (  # noqa: PLC0415
+                 MixedEffectsTrainer,
+             )
+
+             # Create trainer
+             trainer = MixedEffectsTrainer(
+                 model=wrapped_model,
+                 args=training_args,
+                 train_dataset=train_dataset,
+                 eval_dataset=eval_dataset,
+                 data_collator=data_collator,
+                 tokenizer=self.tokenizer,
+                 random_effects_manager=self.random_effects,
+                 compute_metrics=compute_metrics_fn,
+             )
+
+             # Train
+             train_result = trainer.train()
+
+             # Get training metrics
+             train_metrics = trainer.evaluate(eval_dataset=train_dataset)
+             metrics: dict[str, float] = {
+                 "train_loss": float(train_result.training_loss),
+                 "train_accuracy": train_metrics.get("eval_accuracy", 0.0),
+                 "train_precision": train_metrics.get("eval_precision", 0.0),
+                 "train_recall": train_metrics.get("eval_recall", 0.0),
+                 "train_f1": train_metrics.get("eval_f1", 0.0),
+             }
+
+             # Get validation metrics if eval_dataset was provided
+             if eval_dataset is not None:
+                 val_metrics = trainer.evaluate(eval_dataset=eval_dataset)
+                 metrics.update(
+                     {
+                         "val_accuracy": val_metrics.get("eval_accuracy", 0.0),
+                         "val_precision": val_metrics.get("eval_precision", 0.0),
+                         "val_recall": val_metrics.get("eval_recall", 0.0),
+                         "val_f1": val_metrics.get("eval_f1", 0.0),
+                     }
+                 )
+
+         # Estimate variance components
+         if self.config.mixed_effects.estimate_variance_components:
+             var_comps = self.random_effects.estimate_variance_components()
+             if var_comps:
+                 var_comp = var_comps.get("mu") or var_comps.get("slopes")
+                 if var_comp:
+                     self.variance_history.append(var_comp)
+                     metrics["participant_variance"] = var_comp.variance
+                     metrics["n_participants"] = var_comp.n_groups
+
+         self._is_fitted = True
+
+         return metrics
+
+     def _train_with_custom_loop(
+         self,
+         items: list[Item],
+         y_numeric: list[int],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> dict[str, float]:
+         """Train using custom training loop (for random_slopes mode).
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         y_numeric : list[int]
+             Numeric labels (class indices).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         # Convert to tensor
+         y = torch.tensor(y_numeric, dtype=torch.long, device=self.config.device)
+
+         # Build optimizer parameters
+         params_to_optimize = list(self.encoder.parameters()) + list(
+             self.classifier_head.parameters()
+         )
+
+         # Add random effects parameters (for random_slopes)
+         if self.config.mixed_effects.mode == "random_slopes":
+             for head in self.random_effects.slopes.values():
+                 params_to_optimize.extend(head.parameters())
+
+         optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
+         criterion = nn.CrossEntropyLoss()
+
+         self.encoder.train()
+         self.classifier_head.train()
+
+         for _epoch in range(self.config.num_epochs):
+             n_batches = (
+                 len(items) + self.config.batch_size - 1
+             ) // self.config.batch_size
+             epoch_loss = 0.0
+             epoch_correct = 0
+
+             for i in range(n_batches):
+                 start_idx = i * self.config.batch_size
+                 end_idx = min(start_idx + self.config.batch_size, len(items))
+
+                 batch_items = items[start_idx:end_idx]
+                 batch_labels = y[start_idx:end_idx]
+                 batch_participant_ids = participant_ids[start_idx:end_idx]
+
+                 embeddings = self._prepare_inputs(batch_items)
+
+                 # Forward pass depends on mixed effects mode
+                 if self.config.mixed_effects.mode == "fixed":
+                     # Standard forward pass
+                     logits = self.classifier_head(embeddings)
+
+                 elif self.config.mixed_effects.mode == "random_intercepts":
+                     # Fixed head + per-participant bias
+                     logits = self.classifier_head(embeddings)
+                     for j, pid in enumerate(batch_participant_ids):
+                         bias = self.random_effects.get_intercepts(
+                             pid,
+                             n_classes=self.num_classes,
+                             param_name="mu",
+                             create_if_missing=True,
+                         )
+                         logits[j] = logits[j] + bias
+
+                 elif self.config.mixed_effects.mode == "random_slopes":
+                     # Per-participant head
+                     logits_list = []
+                     for j, pid in enumerate(batch_participant_ids):
+                         participant_head = self.random_effects.get_slopes(
+                             pid,
+                             fixed_head=self.classifier_head,
+                             create_if_missing=True,
+                         )
+                         logits_j = participant_head(embeddings[j : j + 1])
+                         logits_list.append(logits_j)
+                     logits = torch.cat(logits_list, dim=0)
+
+                 # Data loss + prior regularization
+                 loss_ce = criterion(logits, batch_labels)
+                 loss_prior = self.random_effects.compute_prior_loss()
+                 loss = loss_ce + loss_prior
+
+                 optimizer.zero_grad()
+                 loss.backward()
+                 optimizer.step()
+
+                 epoch_loss += loss.item()
+                 predictions = torch.argmax(logits, dim=1)
+                 epoch_correct += (predictions == batch_labels).sum().item()
+
+             epoch_acc = epoch_correct / len(items)
+             epoch_loss = epoch_loss / n_batches
+
+         self._is_fitted = True
+
+         metrics: dict[str, float] = {
+             "train_accuracy": epoch_acc,
+             "train_loss": epoch_loss,
+         }
+
+         # Estimate variance components
+         if self.config.mixed_effects.estimate_variance_components:
+             var_comps = self.random_effects.estimate_variance_components()
+             if var_comps:
+                 var_comp = var_comps.get("mu") or var_comps.get("slopes")
+                 if var_comp:
+                     self.variance_history.append(var_comp)
+                     metrics["participant_variance"] = var_comp.variance
+                     metrics["n_participants"] = var_comp.n_groups
+
+         if validation_items is not None and validation_labels is not None:
+             self._validate_labels(validation_labels)
+
+             if len(validation_items) != len(validation_labels):
+                 raise ValueError(
+                     f"Number of validation items ({len(validation_items)}) "
+                     f"must match number of validation labels ({len(validation_labels)})"
+                 )
+
+             # Validation with placeholder participant_ids for mixed effects
+             if self.config.mixed_effects.mode == "fixed":
+                 val_predictions = self.predict(validation_items, participant_ids=None)
+             else:
+                 val_participant_ids = ["_validation_"] * len(validation_items)
+                 val_predictions = self.predict(
+                     validation_items, participant_ids=val_participant_ids
+                 )
+             val_pred_labels = [p.predicted_class for p in val_predictions]
+             val_acc = sum(
+                 pred == true
+                 for pred, true in zip(val_pred_labels, validation_labels, strict=True)
+             ) / len(validation_labels)
+             metrics["val_accuracy"] = val_acc
+
+         return metrics
+
+     def _do_predict(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> list[ModelPrediction]:
+         """Perform categorical model prediction.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str]
+             Normalized participant IDs.
+
+         Returns
+         -------
+         list[ModelPrediction]
+             Predictions.
+         """
+         self.encoder.eval()
+         self.classifier_head.eval()
+
+         with torch.no_grad():
+             embeddings = self._prepare_inputs(items)
+
+             # Forward pass depends on mixed effects mode
+             if self.config.mixed_effects.mode == "fixed":
+                 logits = self.classifier_head(embeddings)
+
+             elif self.config.mixed_effects.mode == "random_intercepts":
+                 logits = self.classifier_head(embeddings)
+                 for i, pid in enumerate(participant_ids):
+                     # Unknown participants: use prior mean (zero bias)
+                     bias = self.random_effects.get_intercepts(
+                         pid,
+                         n_classes=self.num_classes,
+                         param_name="mu",
+                         create_if_missing=False,
+                     )
+                     logits[i] = logits[i] + bias
+
+             elif self.config.mixed_effects.mode == "random_slopes":
+                 logits_list = []
+                 for i, pid in enumerate(participant_ids):
+                     # Unknown participants: use fixed head
+                     participant_head = self.random_effects.get_slopes(
+                         pid, fixed_head=self.classifier_head, create_if_missing=False
+                     )
+                     logits_i = participant_head(embeddings[i : i + 1])
+                     logits_list.append(logits_i)
+                 logits = torch.cat(logits_list, dim=0)
+
+             proba = torch.softmax(logits, dim=1).cpu().numpy()
+             pred_classes = torch.argmax(logits, dim=1).cpu().numpy()
+
+         predictions = []
+         for i, item in enumerate(items):
+             pred_label = self.category_names[pred_classes[i]]
+             prob_dict = {
+                 cat: float(proba[i, idx]) for idx, cat in enumerate(self.category_names)
+             }
+             predictions.append(
+                 ModelPrediction(
+                     item_id=str(item.id),
+                     probabilities=prob_dict,
+                     predicted_class=pred_label,
+                     confidence=float(proba[i, pred_classes[i]]),
+                 )
+             )
+
+         return predictions
+
+     def _do_predict_proba(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> np.ndarray:
+         """Perform categorical model probability prediction.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str]
+             Normalized participant IDs.
+
+         Returns
+         -------
+         np.ndarray
+             Probability array of shape (n_items, n_classes).
+         """
+         self.encoder.eval()
+         self.classifier_head.eval()
+
+         with torch.no_grad():
+             embeddings = self._prepare_inputs(items)
+
+             # Forward pass depends on mixed effects mode
+             if self.config.mixed_effects.mode == "fixed":
+                 logits = self.classifier_head(embeddings)
+
+             elif self.config.mixed_effects.mode == "random_intercepts":
+                 logits = self.classifier_head(embeddings)
+                 for i, pid in enumerate(participant_ids):
+                     bias = self.random_effects.get_intercepts(
+                         pid,
+                         n_classes=self.num_classes,
+                         param_name="mu",
+                         create_if_missing=False,
+                     )
+                     logits[i] = logits[i] + bias
+
+             elif self.config.mixed_effects.mode == "random_slopes":
+                 logits_list = []
+                 for i, pid in enumerate(participant_ids):
+                     participant_head = self.random_effects.get_slopes(
+                         pid, fixed_head=self.classifier_head, create_if_missing=False
+                     )
+                     logits_i = participant_head(embeddings[i : i + 1])
+                     logits_list.append(logits_i)
+                 logits = torch.cat(logits_list, dim=0)
+
+             proba = torch.softmax(logits, dim=1).cpu().numpy()
+
+         return proba
+
+     def _get_save_state(self) -> dict[str, object]:
+         """Get model-specific state to save.
+
+         Returns
+         -------
+         dict[str, object]
+             State dictionary.
+         """
+         return {
+             "num_classes": self.num_classes,
+             "category_names": self.category_names,
+         }
+
+     def _save_model_components(self, save_path: Path) -> None:
+         """Save model-specific components.
+
+         Parameters
+         ----------
+         save_path : Path
+             Directory to save to.
+         """
+         self.encoder.save_pretrained(save_path / "encoder")
+         self.tokenizer.save_pretrained(save_path / "encoder")
+
+         torch.save(
+             self.classifier_head.state_dict(),
+             save_path / "classifier_head.pt",
+         )
+
+     def _restore_training_state(self, config_dict: dict[str, object]) -> None:
+         """Restore model-specific training state.
+
+         Parameters
+         ----------
+         config_dict : dict[str, object]
+             Configuration dictionary with training state.
+         """
+         self.num_classes = config_dict.pop("num_classes")
+         self.category_names = config_dict.pop("category_names")
+
+     def _load_model_components(self, load_path: Path) -> None:
+         """Load model-specific components.
+
+         Parameters
+         ----------
+         load_path : Path
+             Directory to load from.
+         """
+         # Load config.json to reconstruct config
+         with open(load_path / "config.json") as f:
+             import json  # noqa: PLC0415
+
+             config_dict = json.load(f)
+
+         # Reconstruct MixedEffectsConfig if needed
+         if "mixed_effects" in config_dict and isinstance(
+             config_dict["mixed_effects"], dict
+         ):
+             from bead.active_learning.config import MixedEffectsConfig  # noqa: PLC0415
+
+             config_dict["mixed_effects"] = MixedEffectsConfig(
+                 **config_dict["mixed_effects"]
+             )
+
+         self.config = CategoricalModelConfig(**config_dict)
+
+         self.encoder = AutoModel.from_pretrained(load_path / "encoder")
+         self.tokenizer = AutoTokenizer.from_pretrained(load_path / "encoder")
+
+         self._initialize_classifier(self.num_classes)
+         self.classifier_head.load_state_dict(
+             torch.load(
+                 load_path / "classifier_head.pt", map_location=self.config.device
+             )
+         )
+         self.classifier_head.to(self.config.device)
+
+     def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
+         """Get fixed head for random effects loading.
+
+         Returns
+         -------
+         nn.Module | None
+             Fixed head module.
+         """
+         return self.classifier_head
+
+     def _get_n_classes_for_random_effects(self) -> int:
+         """Get number of classes for random effects initialization.
+
+         Returns
+         -------
+         int
+             Number of classes.
+         """
+         return self.num_classes
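
For orientation, here is a minimal usage sketch of the class shown in the diff above. It is not part of the wheel; it uses only the API visible in this file (CategoricalModel, CategoricalModelConfig, Item, train, predict), and the num_epochs, batch_size, and device keyword arguments are taken from the class docstring example. Any behavior beyond that is an assumption.

# Usage sketch, assuming only what the diff above shows; not part of the package.
from uuid import uuid4

from bead.active_learning.models.categorical import CategoricalModel
from bead.config.active_learning import CategoricalModelConfig
from bead.items.item import Item

# Ten two-element items, mirroring the class docstring example.
items = [
    Item(
        item_template_id=uuid4(),
        rendered_elements={"premise": f"sent A{i}", "hypothesis": f"sent B{i}"},
    )
    for i in range(10)
]
labels = ["entailment"] * 5 + ["contradiction"] * 5

config = CategoricalModelConfig(num_epochs=1, batch_size=2, device="cpu")
model = CategoricalModel(config=config)

# category_names and num_classes are inferred from the sorted unique labels.
metrics = model.train(items, labels, participant_ids=None)
print(metrics["train_loss"], metrics["train_accuracy"])

# Per the code's own comments, unknown participants fall back to the prior
# mean (random_intercepts) or the fixed head (random_slopes) at predict time.
for pred in model.predict(items[:3], participant_ids=None):
    print(pred.predicted_class, pred.confidence)

One design point worth noting from _initialize_classifier and the two _encode_* methods: single_encoder feeds one pooled [CLS] vector of width hidden_size to the classifier head, while dual_encoder encodes each rendered element separately and concatenates the [CLS] vectors, so the head's input width grows to hidden_size × num_classes.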