bead-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,795 @@
+ """Multi-select model for selecting multiple options.
+
+ Expected architecture: Multi-label classification with sigmoid output per option.
+ Each option can be independently selected or not selected.
+ """
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel, AutoTokenizer
+
+ from bead.active_learning.config import MixedEffectsConfig, VarianceComponents
+ from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
+ from bead.active_learning.models.random_effects import RandomEffectsManager
+ from bead.config.active_learning import MultiSelectModelConfig
+ from bead.items.item import Item
+ from bead.items.item_template import ItemTemplate, TaskType
+
+ __all__ = ["MultiSelectModel"]
+
+
+ class MultiSelectModel(ActiveLearningModel):
+     """Model for multi_select tasks with N selectable options.
+
+     Uses multi-label classification where each option can be independently
+     selected or not selected. Applies sigmoid activation to each option's
+     logit and uses BCEWithLogitsLoss for training.
+
+     Parameters
+     ----------
+     config : MultiSelectModelConfig
+         Configuration object containing all model parameters.
+
+     Attributes
+     ----------
+     config : MultiSelectModelConfig
+         Model configuration.
+     tokenizer : AutoTokenizer
+         Transformer tokenizer.
+     encoder : AutoModel
+         Transformer encoder model.
+     classifier_head : nn.Sequential
+         Classification head (fixed effects head) - outputs N logits.
+     num_options : int | None
+         Number of selectable options (inferred from training data).
+     option_names : list[str] | None
+         Option names (e.g., ["option_a", "option_b", "option_c"]).
+     random_effects : RandomEffectsManager
+         Manager for participant-level random effects.
+     variance_history : list[VarianceComponents]
+         Variance component estimates over training (for diagnostics).
+     _is_fitted : bool
+         Whether model has been trained.
+
+     Examples
+     --------
+     >>> from uuid import uuid4
+     >>> from bead.items.item import Item
+     >>> from bead.config.active_learning import MultiSelectModelConfig
+     >>> items = [
+     ...     Item(
+     ...         item_template_id=uuid4(),
+     ...         rendered_elements={
+     ...             "option_a": "First option",
+     ...             "option_b": "Second option",
+     ...             "option_c": "Third option"
+     ...         }
+     ...     )
+     ...     for _ in range(10)
+     ... ]
+     >>> # Labels as lists of selected options
+     >>> labels_list = [["option_a", "option_b"], ["option_c"], ["option_a"]]
+     >>> labels = labels_list * 3 + [["option_b"]]
+     >>> config = MultiSelectModelConfig(  # doctest: +SKIP
+     ...     num_epochs=1, batch_size=2, device="cpu"
+     ... )
+     >>> model = MultiSelectModel(config=config)  # doctest: +SKIP
+     >>> # Convert labels to serialized format for train()
+     >>> label_strs = [json.dumps(sorted(lbls)) for lbls in labels]  # doctest: +SKIP
+     >>> metrics = model.train(items, label_strs, participant_ids=None)  # doctest: +SKIP
+
+     Notes
+     -----
+     This model uses BCEWithLogitsLoss (not CrossEntropyLoss) and applies
+     sigmoid activation to get independent probabilities for each option.
+     Random intercepts are bias vectors (one per option) that shift logits
+     independently for each participant.
+     """
+
+     def __init__(
+         self,
+         config: MultiSelectModelConfig | None = None,
+     ) -> None:
+         """Initialize multi-select model.
+
+         Parameters
+         ----------
+         config : MultiSelectModelConfig | None
+             Configuration object. If None, uses default configuration.
+         """
+         self.config = config or MultiSelectModelConfig()
+
+         # Validate mixed_effects configuration
+         super().__init__(self.config)
+
+         self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+         self.encoder = AutoModel.from_pretrained(self.config.model_name)
+
+         self.num_options: int | None = None
+         self.option_names: list[str] | None = None
+         self.classifier_head: nn.Sequential | None = None
+         self._is_fitted = False
+
+         # Initialize random effects manager
+         self.random_effects: RandomEffectsManager | None = None
+         self.variance_history: list[VarianceComponents] = []
+
+         self.encoder.to(self.config.device)
+
+     @property
+     def supported_task_types(self) -> list[TaskType]:
+         """Get supported task types.
+
+         Returns
+         -------
+         list[TaskType]
+             List containing "multi_select".
+         """
+         return ["multi_select"]
+
+     def validate_item_compatibility(
+         self, item: Item, item_template: ItemTemplate
+     ) -> None:
+         """Validate item is compatible with multi-select model.
+
+         Parameters
+         ----------
+         item : Item
+             Item to validate.
+         item_template : ItemTemplate
+             Template the item was constructed from.
+
+         Raises
+         ------
+         ValueError
+             If task_type is not "multi_select".
+         """
+         if item_template.task_type != "multi_select":
+             raise ValueError(
+                 f"Expected task_type 'multi_select', got '{item_template.task_type}'"
+             )
+
+     def _initialize_classifier(self, num_options: int) -> None:
+         """Initialize classification head for given number of options.
+
+         Parameters
+         ----------
+         num_options : int
+             Number of selectable options (output units).
+         """
+         hidden_size = self.encoder.config.hidden_size
+
+         if self.config.encoder_mode == "dual_encoder":
+             input_size = hidden_size * num_options
+         else:
+             input_size = hidden_size
+
+         self.classifier_head = nn.Sequential(
+             nn.Linear(input_size, 256),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(256, num_options),  # N independent outputs
+         )
+         self.classifier_head.to(self.config.device)
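A standalone sketch of the head dimensions this method sets up, under assumed values (hidden_size=768 and three options are illustrative, not bead defaults):

```python
# Head input size depends on encoder mode; output is always one logit per option.
import torch
import torch.nn as nn

hidden_size, num_options = 768, 3  # assumed values for illustration

for mode, input_size in [
    ("single_encoder", hidden_size),              # one [CLS] vector per item
    ("dual_encoder", hidden_size * num_options),  # concatenated per-option vectors
]:
    head = nn.Sequential(
        nn.Linear(input_size, 256),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(256, num_options),
    )
    logits = head(torch.randn(4, input_size))  # batch of 4 items
    print(mode, logits.shape)                  # torch.Size([4, 3]) in both modes
```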
+
+     def _encode_single(self, texts: list[str]) -> torch.Tensor:
+         """Encode texts using single encoder strategy.
+
+         Concatenates all option texts with [SEP] tokens and encodes once.
+
+         Parameters
+         ----------
+         texts : list[str]
+             List of concatenated option texts for each item.
+
+         Returns
+         -------
+         torch.Tensor
+             Encoded representations of shape (batch_size, hidden_size).
+         """
+         encodings = self.tokenizer(
+             texts,
+             padding=True,
+             truncation=True,
+             max_length=self.config.max_length,
+             return_tensors="pt",
+         )
+         encodings = {k: v.to(self.config.device) for k, v in encodings.items()}
+
+         outputs = self.encoder(**encodings)
+         return outputs.last_hidden_state[:, 0, :]
+
+     def _encode_dual(self, options_per_item: list[list[str]]) -> torch.Tensor:
+         """Encode texts using dual encoder strategy.
+
+         Encodes each option separately and concatenates embeddings.
+
+         Parameters
+         ----------
+         options_per_item : list[list[str]]
+             List of option lists. Each inner list contains option texts for one item.
+
+         Returns
+         -------
+         torch.Tensor
+             Concatenated encodings of shape (batch_size, hidden_size * num_options).
+         """
+         all_embeddings = []
+
+         for options in options_per_item:
+             option_embeddings = []
+             for option_text in options:
+                 encodings = self.tokenizer(
+                     [option_text],
+                     padding=True,
+                     truncation=True,
+                     max_length=self.config.max_length,
+                     return_tensors="pt",
+                 )
+                 encodings = {k: v.to(self.config.device) for k, v in encodings.items()}
+
+                 outputs = self.encoder(**encodings)
+                 cls_embedding = outputs.last_hidden_state[0, 0, :]
+                 option_embeddings.append(cls_embedding)
+
+             concatenated = torch.cat(option_embeddings, dim=0)
+             all_embeddings.append(concatenated)
+
+         return torch.stack(all_embeddings)
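What the two encoder modes actually feed the transformer, sketched with illustrative rendered_elements (the keys and texts are assumptions, matching the class docstring example):

```python
# single_encoder joins option texts into one string; dual_encoder encodes
# each option on its own and concatenates the [CLS] vectors.
rendered_elements = {
    "option_a": "First option",
    "option_b": "Second option",
    "option_c": "Third option",
}
option_names = sorted(rendered_elements)

# single_encoder: one string per item -> (batch, hidden_size)
single_input = " [SEP] ".join(rendered_elements[opt] for opt in option_names)
print(single_input)  # First option [SEP] Second option [SEP] Third option

# dual_encoder: one string per option -> (batch, hidden_size * num_options)
dual_inputs = [rendered_elements[opt] for opt in option_names]
print(dual_inputs)   # ['First option', 'Second option', 'Third option']
```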
+
+     def _prepare_inputs(self, items: list[Item]) -> torch.Tensor:
+         """Prepare inputs for encoding based on encoder mode.
+
+         For multi-select tasks, uses all options from rendered_elements.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to encode.
+
+         Returns
+         -------
+         torch.Tensor
+             Encoded representations.
+         """
+         if self.option_names is None:
+             raise ValueError("Model not initialized. Call train() first.")
+
+         if self.config.encoder_mode == "single_encoder":
+             texts = []
+             for item in items:
+                 option_texts = [
+                     item.rendered_elements.get(opt, "") for opt in self.option_names
+                 ]
+                 concatenated = " [SEP] ".join(option_texts)
+                 texts.append(concatenated)
+             return self._encode_single(texts)
+         else:
+             options_per_item = []
+             for item in items:
+                 option_texts = [
+                     item.rendered_elements.get(opt, "") for opt in self.option_names
+                 ]
+                 options_per_item.append(option_texts)
+             return self._encode_dual(options_per_item)
+
+     def _parse_multi_select_labels(self, label_str: str) -> list[str]:
+         """Parse multi-select label from JSON string.
+
+         Parameters
+         ----------
+         label_str : str
+             JSON-serialized list of selected options.
+
+         Returns
+         -------
+         list[str]
+             List of selected option names.
+         """
+         try:
+             selected = json.loads(label_str)
+             if not isinstance(selected, list):
+                 raise ValueError(
+                     f"Label must be JSON list of option names, got {type(selected)}"
+                 )
+             return selected
+         except json.JSONDecodeError as e:
+             raise ValueError(
+                 f"Label must be valid JSON list of selected options. "
+                 f"Got: {label_str!r}. Error: {e}"
+             ) from e
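The label format this parser expects, round-tripped in plain Python (option names are illustrative):

```python
# Labels travel as JSON-encoded lists of option names.
import json

label_str = json.dumps(sorted(["option_c", "option_a"]))
print(label_str)              # ["option_a", "option_c"]
print(json.loads(label_str))  # ['option_a', 'option_c']

# A bare string is valid JSON but not a list, so the isinstance check above
# rejects it; malformed JSON raises JSONDecodeError, re-raised as ValueError.
print(json.loads('"option_a"'))  # 'option_a' -> rejected by the list check
```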
+
+     def _prepare_training_data(
+         self,
+         items: list[Item],
+         labels: list[str],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> tuple[
+         list[Item], torch.Tensor, list[str], list[Item] | None, torch.Tensor | None
+     ]:
+         """Prepare training data for multi-select model.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels : list[str]
+             Training labels (JSON strings of selected options).
+         participant_ids : list[str]
+             Normalized participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+
+         Returns
+         -------
+         tuple
+             Prepared items, labels, participant_ids, val items, val labels.
+         """
+         if not items:
+             raise ValueError("Cannot train with empty items list")
+
+         # Infer option names from first item
+         self.option_names = sorted(items[0].rendered_elements.keys())
+         self.num_options = len(self.option_names)
+         option_to_idx = {opt: idx for idx, opt in enumerate(self.option_names)}
+
+         # Parse labels and convert to binary matrix
+         y = torch.zeros(
+             (len(items), self.num_options), dtype=torch.float, device=self.config.device
+         )
+         for i, label_str in enumerate(labels):
+             selected_options = self._parse_multi_select_labels(label_str)
+             for opt in selected_options:
+                 if opt not in option_to_idx:
+                     raise ValueError(
+                         f"Invalid option {opt!r} in label. "
+                         f"Valid options: {self.option_names}"
+                     )
+                 y[i, option_to_idx[opt]] = 1.0
+
+         self._initialize_classifier(self.num_options)
+
+         # Convert validation labels if provided
+         val_y = None
+         if validation_items is not None and validation_labels is not None:
+             if len(validation_items) != len(validation_labels):
+                 raise ValueError(
+                     f"Number of validation items ({len(validation_items)}) "
+                     f"must match number of validation labels ({len(validation_labels)})"
+                 )
+             val_y = torch.zeros(
+                 (len(validation_items), self.num_options),
+                 dtype=torch.float,
+                 device=self.config.device,
+             )
+             for i, label_str in enumerate(validation_labels):
+                 selected_options = self._parse_multi_select_labels(label_str)
+                 for opt in selected_options:
+                     if opt not in option_to_idx:
+                         raise ValueError(
+                             f"Invalid option {opt!r} in validation label. "
+                             f"Valid options: {self.option_names}"
+                         )
+                     val_y[i, option_to_idx[opt]] = 1.0
+
+         return items, y, participant_ids, validation_items, val_y
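The multi-hot conversion at the heart of this method, as a standalone sketch with illustrative option names and labels:

```python
# Label lists -> multi-hot matrix, mirroring the y tensor built above.
import torch

option_names = ["option_a", "option_b", "option_c"]
option_to_idx = {opt: idx for idx, opt in enumerate(option_names)}
selected_per_item = [["option_a", "option_b"], ["option_c"], []]

y = torch.zeros((len(selected_per_item), len(option_names)))
for i, selected in enumerate(selected_per_item):
    for opt in selected:
        y[i, option_to_idx[opt]] = 1.0

print(y)
# tensor([[1., 1., 0.],
#         [0., 0., 1.],
#         [0., 0., 0.]])
```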
+
+     def _initialize_random_effects(self, n_classes: int) -> None:
+         """Initialize random effects manager.
+
+         Parameters
+         ----------
+         n_classes : int
+             Number of classes (num_options for multi-select).
+         """
+         self.random_effects = RandomEffectsManager(
+             self.config.mixed_effects, n_classes=n_classes
+         )
+
+     def _do_training(
+         self,
+         items: list[Item],
+         labels_numeric: torch.Tensor,
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels_numeric: torch.Tensor | None,
+     ) -> dict[str, float]:
+         """Perform multi-select model training.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels_numeric : torch.Tensor
+             Binary label tensor of shape (n_items, n_options).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels_numeric : torch.Tensor | None
+             Validation label tensor.
+
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         y = labels_numeric
+
+         # Build optimizer parameters based on mode
+         params_to_optimize = list(self.encoder.parameters()) + list(
+             self.classifier_head.parameters()
+         )
+
+         # Add random effects parameters
+         if self.config.mixed_effects.mode == "random_intercepts":
+             for param_dict in self.random_effects.intercepts.values():
+                 params_to_optimize.extend(param_dict.values())
+         elif self.config.mixed_effects.mode == "random_slopes":
+             for head in self.random_effects.slopes.values():
+                 params_to_optimize.extend(head.parameters())
+
+         optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
+         # BCE with Logits Loss for multi-label classification
+         criterion = nn.BCEWithLogitsLoss()
+
+         self.encoder.train()
+         self.classifier_head.train()
+
+         epoch_acc = 0.0
+         epoch_loss = 0.0
+
+         for _epoch in range(self.config.num_epochs):
+             n_batches = (
+                 len(items) + self.config.batch_size - 1
+             ) // self.config.batch_size
+             epoch_loss = 0.0
+             epoch_correct_predictions = 0
+             epoch_total_predictions = 0
+
+             for i in range(n_batches):
+                 start_idx = i * self.config.batch_size
+                 end_idx = min(start_idx + self.config.batch_size, len(items))
+
+                 batch_items = items[start_idx:end_idx]
+                 batch_labels = y[start_idx:end_idx]
+                 batch_participant_ids = participant_ids[start_idx:end_idx]
+
+                 embeddings = self._prepare_inputs(batch_items)
+
+                 # Forward pass depends on mixed effects mode
+                 if self.config.mixed_effects.mode == "fixed":
+                     # Standard forward pass
+                     logits = self.classifier_head(embeddings)
+
+                 elif self.config.mixed_effects.mode == "random_intercepts":
+                     # Fixed head + per-participant bias (independent per option)
+                     logits = self.classifier_head(embeddings)
+                     for j, pid in enumerate(batch_participant_ids):
+                         bias = self.random_effects.get_intercepts(
+                             pid,
+                             n_classes=self.num_options,
+                             param_name="mu",
+                             create_if_missing=True,
+                         )
+                         logits[j] = logits[j] + bias
+
+                 elif self.config.mixed_effects.mode == "random_slopes":
+                     # Per-participant head
+                     logits_list = []
+                     for j, pid in enumerate(batch_participant_ids):
+                         participant_head = self.random_effects.get_slopes(
+                             pid,
+                             fixed_head=self.classifier_head,
+                             create_if_missing=True,
+                         )
+                         logits_j = participant_head(embeddings[j : j + 1])
+                         logits_list.append(logits_j)
+                     logits = torch.cat(logits_list, dim=0)
+
+                 # Data loss + prior regularization
+                 loss_bce = criterion(logits, batch_labels)
+                 loss_prior = self.random_effects.compute_prior_loss()
+                 loss = loss_bce + loss_prior
+
+                 optimizer.zero_grad()
+                 loss.backward()
+                 optimizer.step()
+
+                 epoch_loss += loss.item()
+
+                 # Predictions: threshold at 0.5 on sigmoid(logits)
+                 predictions = (torch.sigmoid(logits) > 0.5).float()
+                 # Hamming accuracy: fraction of correct predictions (per option)
+                 batch_correct = (predictions == batch_labels).sum().item()
+                 batch_total = batch_labels.numel()
+                 epoch_correct_predictions += batch_correct
+                 epoch_total_predictions += batch_total
+
+             # Hamming accuracy: average over all (item, option) pairs
+             epoch_acc = epoch_correct_predictions / epoch_total_predictions
+             epoch_loss = epoch_loss / n_batches
+
+         metrics: dict[str, float] = {
+             "train_accuracy": epoch_acc,
+             "train_loss": epoch_loss,
+         }
+
+         # Add validation accuracy if validation data provided
+         if validation_items is not None and validation_labels_numeric is not None:
+             # Validation with placeholder participant_ids for mixed effects
+             if self.config.mixed_effects.mode == "fixed":
+                 val_participant_ids = ["_fixed_"] * len(validation_items)
+             else:
+                 val_participant_ids = ["_validation_"] * len(validation_items)
+             val_predictions = self._do_predict(validation_items, val_participant_ids)
+
+             # Parse validation labels
+             val_labels_parsed = []
+             for i in range(validation_labels_numeric.shape[0]):
+                 selected = [
+                     self.option_names[j]
+                     for j in range(self.num_options)
+                     if validation_labels_numeric[i, j] > 0.5
+                 ]
+                 val_labels_parsed.append(set(selected))
+
+             # Compute Hamming accuracy
+             val_correct = 0
+             val_total = 0
+             for pred, true_set in zip(val_predictions, val_labels_parsed, strict=True):
+                 # pred.predicted_class is JSON string of selected options
+                 pred_set = set(json.loads(pred.predicted_class))
+                 for opt in self.option_names:
+                     if (opt in pred_set) == (opt in true_set):
+                         val_correct += 1
+                     val_total += 1
+
+             val_acc = val_correct / val_total
+             metrics["val_accuracy"] = val_acc
+
+         return metrics
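Both the train and validation metrics above are Hamming accuracy: the fraction of correct per-(item, option) decisions, not exact-match accuracy over whole selection sets. A standalone sketch with illustrative tensors:

```python
# Hamming accuracy over thresholded sigmoid outputs.
import torch

logits = torch.tensor([[2.0, -1.0, 0.5], [-0.5, 1.5, -2.0]])
targets = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

predictions = (torch.sigmoid(logits) > 0.5).float()
hamming_acc = (predictions == targets).sum().item() / targets.numel()
print(predictions)  # tensor([[1., 0., 1.], [0., 1., 0.]])
print(hamming_acc)  # 5 of 6 decisions correct -> 0.833...
```

Note that the second item's selection set is predicted exactly, while the first has one wrong option; exact-match accuracy would be 0.5 here, which is why the per-option metric is the gentler diagnostic during early rounds of active learning.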
+
+     def _do_predict(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> list[ModelPrediction]:
+         """Perform multi-select model prediction.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str]
+             Normalized participant IDs.
+
+         Returns
+         -------
+         list[ModelPrediction]
+             Predictions.
+         """
+         self.encoder.eval()
+         self.classifier_head.eval()
+
+         with torch.no_grad():
+             embeddings = self._prepare_inputs(items)
+
+             # Forward pass depends on mixed effects mode
+             if self.config.mixed_effects.mode == "fixed":
+                 logits = self.classifier_head(embeddings)
+
+             elif self.config.mixed_effects.mode == "random_intercepts":
+                 logits = self.classifier_head(embeddings)
+                 for i, pid in enumerate(participant_ids):
+                     # Unknown participants: use prior mean (zero bias)
+                     bias = self.random_effects.get_intercepts(
+                         pid,
+                         n_classes=self.num_options,
+                         param_name="mu",
+                         create_if_missing=False,
+                     )
+                     logits[i] = logits[i] + bias
+
+             elif self.config.mixed_effects.mode == "random_slopes":
+                 logits_list = []
+                 for i, pid in enumerate(participant_ids):
+                     # Unknown participants: use fixed head
+                     participant_head = self.random_effects.get_slopes(
+                         pid, fixed_head=self.classifier_head, create_if_missing=False
+                     )
+                     logits_i = participant_head(embeddings[i : i + 1])
+                     logits_list.append(logits_i)
+                 logits = torch.cat(logits_list, dim=0)
+
+             # Compute probabilities using sigmoid
+             proba = torch.sigmoid(logits).cpu().numpy()  # (n_items, n_options)
+             pred_binary = proba > 0.5  # Threshold at 0.5
+
+         predictions = []
+         for i, item in enumerate(items):
+             # Determine selected options
+             selected_options = [
+                 self.option_names[j]
+                 for j in range(self.num_options)
+                 if pred_binary[i, j]
+             ]
+
+             # Build probability dict: {option: probability}
+             prob_dict = {
+                 opt: float(proba[i, idx]) for idx, opt in enumerate(self.option_names)
+             }
+
+             # Confidence: average probability of selected options (or 0.5 if none)
+             if selected_options:
+                 option_probs = [
+                     proba[i, self.option_names.index(opt)] for opt in selected_options
+                 ]
+                 confidence = float(np.mean(option_probs))
+             else:
+                 confidence = 0.5  # Neutral confidence when nothing selected
+
+             predictions.append(
+                 ModelPrediction(
+                     item_id=str(item.id),
+                     probabilities=prob_dict,
+                     predicted_class=json.dumps(sorted(selected_options)),
+                     confidence=confidence,
+                 )
+             )
+
+         return predictions
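How one ModelPrediction is assembled from a row of sigmoid probabilities, sketched with illustrative numbers:

```python
# Per-option probabilities -> selection set, JSON predicted_class, confidence.
import json
import numpy as np

option_names = ["option_a", "option_b", "option_c"]
proba_row = np.array([0.91, 0.12, 0.64])  # illustrative sigmoid outputs

selected = [opt for opt, p in zip(option_names, proba_row) if p > 0.5]
prob_dict = {opt: float(p) for opt, p in zip(option_names, proba_row)}
predicted_class = json.dumps(sorted(selected))
# Confidence: mean probability of selected options, 0.5 if none selected
confidence = float(np.mean([prob_dict[o] for o in selected])) if selected else 0.5

print(predicted_class)  # ["option_a", "option_c"]
print(confidence)       # (0.91 + 0.64) / 2, about 0.775
```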
+
+     def _do_predict_proba(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> np.ndarray:
+         """Perform multi-select model probability prediction.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str]
+             Normalized participant IDs.
+
+         Returns
+         -------
+         np.ndarray
+             Probability array of shape (n_items, n_options).
+         """
+         self.encoder.eval()
+         self.classifier_head.eval()
+
+         with torch.no_grad():
+             embeddings = self._prepare_inputs(items)
+
+             # Forward pass depends on mixed effects mode
+             if self.config.mixed_effects.mode == "fixed":
+                 logits = self.classifier_head(embeddings)
+
+             elif self.config.mixed_effects.mode == "random_intercepts":
+                 logits = self.classifier_head(embeddings)
+                 for i, pid in enumerate(participant_ids):
+                     bias = self.random_effects.get_intercepts(
+                         pid,
+                         n_classes=self.num_options,
+                         param_name="mu",
+                         create_if_missing=False,
+                     )
+                     logits[i] = logits[i] + bias
+
+             elif self.config.mixed_effects.mode == "random_slopes":
+                 logits_list = []
+                 for i, pid in enumerate(participant_ids):
+                     participant_head = self.random_effects.get_slopes(
+                         pid, fixed_head=self.classifier_head, create_if_missing=False
+                     )
+                     logits_i = participant_head(embeddings[i : i + 1])
+                     logits_list.append(logits_i)
+                 logits = torch.cat(logits_list, dim=0)
+
+             # Compute probabilities using sigmoid
+             proba = torch.sigmoid(logits).cpu().numpy()
+
+         return proba
+
+     def _get_save_state(self) -> dict[str, object]:
+         """Get model-specific state to save.
+
+         Returns
+         -------
+         dict[str, object]
+             State dictionary.
+         """
+         return {
+             "num_options": self.num_options,
+             "option_names": self.option_names,
+         }
+
+     def _save_model_components(self, save_path: Path) -> None:
+         """Save model-specific components.
+
+         Parameters
+         ----------
+         save_path : Path
+             Directory to save to.
+         """
+         self.encoder.save_pretrained(save_path / "encoder")
+         self.tokenizer.save_pretrained(save_path / "encoder")
+
+         torch.save(
+             self.classifier_head.state_dict(),
+             save_path / "classifier_head.pt",
+         )
+
+     def _restore_training_state(self, config_dict: dict[str, object]) -> None:
+         """Restore model-specific training state.
+
+         Parameters
+         ----------
+         config_dict : dict[str, object]
+             Configuration dictionary with training state.
+         """
+         self.num_options = config_dict.pop("num_options")
+         self.option_names = config_dict.pop("option_names")
+
+     def _load_model_components(self, load_path: Path) -> None:
+         """Load model-specific components.
+
+         Parameters
+         ----------
+         load_path : Path
+             Directory to load from.
+         """
+         # Load config.json to reconstruct config
+         with open(load_path / "config.json") as f:
+             config_dict = json.load(f)
+
+         # Reconstruct MixedEffectsConfig if needed
+         if "mixed_effects" in config_dict and isinstance(
+             config_dict["mixed_effects"], dict
+         ):
+             config_dict["mixed_effects"] = MixedEffectsConfig(
+                 **config_dict["mixed_effects"]
+             )
+
+         self.config = MultiSelectModelConfig(**config_dict)
+
+         self.encoder = AutoModel.from_pretrained(load_path / "encoder")
+         self.tokenizer = AutoTokenizer.from_pretrained(load_path / "encoder")
+
+         self._initialize_classifier(self.num_options)
+         self.classifier_head.load_state_dict(
+             torch.load(
+                 load_path / "classifier_head.pt", map_location=self.config.device
+             )
+         )
+         self.classifier_head.to(self.config.device)
+
+     def _get_random_effects_fixed_head(self) -> nn.Sequential | None:
+         """Get fixed head for random effects loading.
+
+         Returns
+         -------
+         nn.Sequential | None
+             Fixed head module.
+         """
+         return self.classifier_head
+
+     def _get_n_classes_for_random_effects(self) -> int:
+         """Get number of classes for random effects initialization.
+
+         Returns
+         -------
+         int
+             Number of options.
+         """
+         return self.num_options
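Taken together, the save/load hooks imply a checkpoint layout. The public save()/load() entry points live in the ActiveLearningModel base class (bead/active_learning/models/base.py, not shown in this diff), so treat this as inferred rather than documented API; the helper below is hypothetical:

```python
# Inferred checkpoint layout, based on _save_model_components and
# _load_model_components above (config.json is assumed to be written by
# the base class, since _load_model_components reads it):
#
#   <save_path>/
#     config.json          # model config, incl. mixed_effects (read on load)
#     classifier_head.pt   # torch state_dict for the classification head
#     encoder/             # HF encoder + tokenizer (save_pretrained format)
from pathlib import Path

def inspect_checkpoint(save_path: Path) -> list[str]:
    """List which expected checkpoint files are present (hypothetical helper)."""
    expected = ["config.json", "classifier_head.pt", "encoder"]
    return [name for name in expected if (save_path / name).exists()]
```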