bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,852 @@
+ """Base interfaces for active learning models with mixed effects support.
+
+ This module implements Generalized Linear Mixed Effects Models (GLMMs) following
+ the standard formulation:
+
+ y = Xβ + Zu + ε
+
+ Where:
+ - Xβ: Fixed effects (population-level parameters, shared across all groups)
+ - Zu: Random effects (group-specific parameters, e.g., per-participant)
+ - u ~ N(0, G): Random effects with variance-covariance matrix G
+ - ε: Residuals
+
+ The implementation supports three modeling modes:
+ 1. Fixed effects: Standard model, ignores grouping structure
+ 2. Random intercepts: Per-group biases (Zu = bias vector per group)
+ 3. Random slopes: Per-group model parameters (Zu = separate model head per group)
+
+ References
+ ----------
+ - Bates et al. (2015). "Fitting Linear Mixed-Effects Models using lme4"
+ - Simchoni & Rosset (2022). "Integrating Random Effects in Deep Neural Networks"
+ """
+
+ from __future__ import annotations
+
+ import json
+ from abc import ABC, abstractmethod
+ from collections import Counter
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+
+ from bead.active_learning.config import (
+     MixedEffectsConfig,
+     RandomEffectsSpec,
+     VarianceComponents,
+ )
+ from bead.data.base import BeadBaseModel
+ from bead.items.item import Item
+
+ if TYPE_CHECKING:
+     import torch
+
+     from bead.items.item_template import ItemTemplate, TaskType
+
+ __all__ = [
+     "ActiveLearningModel",
+     "ModelPrediction",
+     "MixedEffectsConfig",
+     "VarianceComponents",
+     "RandomEffectsSpec",
+ ]
+
+
+ class ModelPrediction(BeadBaseModel):
+     """Prediction output for a single item.
+
+     Attributes
+     ----------
+     item_id : str
+         Unique identifier for the item.
+     probabilities : dict[str, float]
+         Predicted probabilities for each class/option.
+         Keys are option names (e.g., "option_a", "option_b") or class labels.
+     predicted_class : str
+         The predicted class/option with highest probability.
+     confidence : float
+         Confidence score (max probability).
+
+     Examples
+     --------
+     >>> prediction = ModelPrediction(
+     ...     item_id="abc123",
+     ...     probabilities={"option_a": 0.7, "option_b": 0.3},
+     ...     predicted_class="option_a",
+     ...     confidence=0.7
+     ... )
+     >>> prediction.predicted_class
+     'option_a'
+     """
+
+     item_id: str
+     probabilities: dict[str, float]
+     predicted_class: str
+     confidence: float
+
+
+ class ActiveLearningModel(ABC):
+     """Base class for all active learning models with mixed effects support.
+
+     Implements GLMM-based active learning: y = Xβ + Zu + ε
+
+     All models must:
+     1. Support mixed effects (fixed, random_intercepts, random_slopes modes)
+     2. Accept participant_ids in train/predict/predict_proba (None for fixed effects)
+     3. Validate items match supported task types
+     4. Track variance components (if estimate_variance_components=True)
+
+     Attributes
+     ----------
+     config : dict[str, str | int | float | bool | None] | BeadBaseModel
+         Model configuration (task-type-specific).
+         Must include a `mixed_effects: MixedEffectsConfig` field.
+     supported_task_types : list[TaskType]
+         List of task types this model can handle.
+
+     Examples
+     --------
+     >>> class MyModel(ActiveLearningModel):
+     ...     def __init__(self, config):
+     ...         super().__init__(config) # Validates mixed_effects field
+     ...     @property
+     ...     def supported_task_types(self):
+     ...         return ["forced_choice"]
+     ...     def validate_item_compatibility(self, item, item_template):
+     ...         pass
+     ...     def train(self, items, labels, participant_ids):
+     ...         return {}
+     ...     def predict(self, items, participant_ids):
+     ...         return []
+     ...     def predict_proba(self, items, participant_ids):
+     ...         return np.array([])
+     ...     def save(self, path):
+     ...         pass
+     ...     def load(self, path):
+     ...         pass
+     """
+
+     def __init__(
+         self, config: dict[str, str | int | float | bool | None] | BeadBaseModel
+     ) -> None:
+         """Initialize model with configuration.
+
+         Parameters
+         ----------
+         config : dict[str, str | int | float | bool | None] | BeadBaseModel
+             Model configuration. Must have a `mixed_effects` field of type
+             MixedEffectsConfig.
+
+         Raises
+         ------
+         ValueError
+             If config is invalid or missing required fields.
+
+         Examples
+         --------
+         >>> from bead.config.active_learning import ForcedChoiceModelConfig
+         >>> config = ForcedChoiceModelConfig(
+         ...     n_classes=2,
+         ...     mixed_effects=MixedEffectsConfig(mode='fixed')
+         ... )
+         >>> model = ForcedChoiceModel(config) # doctest: +SKIP
+         """
+         self.config = config
+
+         # Validate mixed_effects field exists
+         if not hasattr(config, "mixed_effects"):
+             raise ValueError(
+                 f"Model config must have a 'mixed_effects' field of type "
+                 f"MixedEffectsConfig, but {type(config).__name__} has no such field. "
+                 f"Add: mixed_effects: MixedEffectsConfig = "
+                 f"Field(default_factory=MixedEffectsConfig)"
+             )
+
+         # Validate mixed_effects is correct type
+         if not isinstance(config.mixed_effects, MixedEffectsConfig):
+             raise ValueError(
+                 f"config.mixed_effects must be MixedEffectsConfig, but got "
+                 f"{type(config.mixed_effects).__name__}. "
+                 f"Ensure the field is properly typed: mixed_effects: MixedEffectsConfig"
+             )
+
+     def _validate_items_labels_length(
+         self, items: list[Item], labels: list[str]
+     ) -> None:
+         """Validate that items and labels have the same length.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels : list[str]
+             Training labels.
+
+         Raises
+         ------
+         ValueError
+             If items and labels have different lengths.
+         """
+         if len(items) != len(labels):
+             raise ValueError(
+                 f"Number of items ({len(items)}) must match "
+                 f"number of labels ({len(labels)})"
+             )
+
+     def _validate_participant_ids_required(
+         self, participant_ids: list[str] | None, mode: str
+     ) -> None:
+         """Validate that participant_ids is provided when required.
+
+         Parameters
+         ----------
+         participant_ids : list[str] | None
+             Participant IDs to validate.
+         mode : str
+             Mixed effects mode ('fixed', 'random_intercepts', 'random_slopes').
+
+         Raises
+         ------
+         ValueError
+             If participant_ids is None when mode requires it.
+         """
+         if participant_ids is None and mode != "fixed":
+             raise ValueError(
+                 f"participant_ids is required when mode='{mode}'. "
+                 f"For fixed effects, set mode='fixed' in config. "
+                 f"For mixed effects, provide participant_ids as list[str]."
+             )
+
+     def _validate_participant_ids_length(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> None:
+         """Validate that items and participant_ids have the same length.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         participant_ids : list[str]
+             Participant IDs.
+
+         Raises
+         ------
+         ValueError
+             If items and participant_ids have different lengths.
+         """
+         if len(items) != len(participant_ids):
+             raise ValueError(
+                 f"Length mismatch: {len(items)} items != {len(participant_ids)} "
+                 f"participant_ids. participant_ids must have same length as items."
+             )
+
+     def _validate_participant_ids_not_empty(self, participant_ids: list[str]) -> None:
+         """Validate that participant_ids does not contain empty strings.
+
+         Parameters
+         ----------
+         participant_ids : list[str]
+             Participant IDs to validate.
+
+         Raises
+         ------
+         ValueError
+             If participant_ids contains empty strings.
+         """
+         if any(not pid for pid in participant_ids):
+             raise ValueError(
+                 "participant_ids cannot contain empty strings. "
+                 "Ensure all participants have valid identifiers."
+             )
+
+     def _normalize_participant_ids(
+         self,
+         participant_ids: list[str] | None,
+         items: list[Item],
+         mode: str,
+     ) -> list[str]:
+         """Normalize participant_ids based on mode.
+
+         For fixed mode, replaces participant_ids with dummy values.
+         For mixed effects modes, validates and returns participant_ids as-is.
+
+         Parameters
+         ----------
+         participant_ids : list[str] | None
+             Participant IDs (may be None for fixed mode).
+         items : list[Item]
+             Training items (used to determine length).
+         mode : str
+             Mixed effects mode ('fixed', 'random_intercepts', 'random_slopes').
+
+         Returns
+         -------
+         list[str]
+             Normalized participant IDs (all "_fixed_" for fixed mode).
+
+         Raises
+         ------
+         ValueError
+             If participant_ids is None when mode requires it.
+         ValueError
+             If items and participant_ids have different lengths.
+         ValueError
+             If participant_ids contains empty strings.
+         """
+         import warnings # noqa: PLC0415
+
+         if participant_ids is None:
+             if mode != "fixed":
+                 self._validate_participant_ids_required(participant_ids, mode)
+             return ["_fixed_"] * len(items)
+
+         # Validate length and empty strings before normalizing
+         self._validate_participant_ids_length(items, participant_ids)
+         self._validate_participant_ids_not_empty(participant_ids)
+
+         if mode == "fixed":
+             warnings.warn(
+                 "participant_ids provided but mode='fixed'. "
+                 "Participant IDs will be ignored.",
+                 UserWarning,
+                 stacklevel=3,
+             )
+             return ["_fixed_"] * len(items)
+
+         return participant_ids
+
+     @property
+     @abstractmethod
+     def supported_task_types(self) -> list[TaskType]:
+         """Get list of task types this model supports.
+
+         Returns
+         -------
+         list[TaskType]
+             List of supported TaskType literals from items.item_template.
+
+         Examples
+         --------
+         >>> model.supported_task_types
+         ['forced_choice']
+         """
+         pass
+
+     @abstractmethod
+     def validate_item_compatibility(
+         self, item: Item, item_template: ItemTemplate
+     ) -> None:
+         """Validate that an item is compatible with this model.
+
+         Parameters
+         ----------
+         item : Item
+             Item to validate.
+         item_template : ItemTemplate
+             Template the item was constructed from.
+
+         Raises
+         ------
+         ValueError
+             If item's task_type is not in supported_task_types.
+         ValueError
+             If item is missing required elements.
+         ValueError
+             If item structure is incompatible with model.
+
+         Examples
+         --------
+         >>> model.validate_item_compatibility(item, template) # doctest: +SKIP
+         """
+         pass
+
+     # Hook methods for model-specific implementations
+     @abstractmethod
+     def _prepare_training_data(
+         self,
+         items: list[Item],
+         labels: list[str],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> tuple[list[Item], list, list[str], list[Item] | None, list | None]:
+         """Prepare training data for model-specific training.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels : list[str]
+             Training labels.
+         participant_ids : list[str]
+             Normalized participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+
+         Returns
+         -------
+         tuple[list[Item], list, list[str], list[Item] | None, list | None]
+             Items, labels, participant_ids, val_items, val_labels.
+         """
+         pass
+
+     @abstractmethod
+     def _initialize_random_effects(self, n_classes: int) -> None:
+         """Initialize random effects manager.
+
+         Parameters
+         ----------
+         n_classes : int
+             Number of classes for random effects.
+         """
+         pass
+
+     @abstractmethod
+     def _do_training(
+         self,
+         items: list[Item],
+         labels_numeric: list,
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels_numeric: list | None,
+     ) -> dict[str, float]:
+         """Perform model-specific training.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels_numeric : list
+             Numeric labels (format depends on model).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels_numeric : list | None
+             Numeric validation labels.
+
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         pass
+
+     @abstractmethod
+     def _do_predict(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> list[ModelPrediction]:
+         """Perform model-specific prediction.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str]
+             Normalized participant IDs.
+
+         Returns
+         -------
+         list[ModelPrediction]
+             Predictions.
+         """
+         pass
+
+     @abstractmethod
+     def _do_predict_proba(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> np.ndarray:
+         """Perform model-specific probability prediction.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str]
+             Normalized participant IDs.
+
+         Returns
+         -------
+         np.ndarray
+             Probability array.
+         """
+         pass
+
+     @abstractmethod
+     def _get_save_state(self) -> dict[str, object]:
+         """Get model-specific state to save.
+
+         Returns
+         -------
+         dict[str, object]
+             State dictionary to include in config.json.
+         """
+         pass
+
+     @abstractmethod
+     def _save_model_components(self, save_path: Path) -> None:
+         """Save model-specific components (encoder, head, etc.).
+
+         Parameters
+         ----------
+         save_path : Path
+             Directory to save to.
+         """
+         pass
+
+     @abstractmethod
+     def _load_model_components(self, load_path: Path) -> None:
+         """Load model-specific components.
+
+         Parameters
+         ----------
+         load_path : Path
+             Directory to load from.
+         """
+         pass
+
+     @abstractmethod
+     def _restore_training_state(self, config_dict: dict[str, object]) -> None:
+         """Restore model-specific training state.
+
+         Parameters
+         ----------
+         config_dict : dict[str, object]
+             Configuration dictionary with training state.
+         """
+         pass
+
+     @abstractmethod
+     def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
+         """Get fixed head for random effects loading.
+
+         Returns
+         -------
+         nn.Module | None
+             Fixed head module, or None if not applicable.
+         """
+         pass
+
+     @abstractmethod
+     def _get_n_classes_for_random_effects(self) -> int:
+         """Get number of classes for random effects initialization.
+
+         Returns
+         -------
+         int
+             Number of classes.
+         """
+         pass
+
+     # Common implementations
+     def train(
+         self,
+         items: list[Item],
+         labels: list[str] | list[list[str]],
+         participant_ids: list[str] | None = None,
+         validation_items: list[Item] | None = None,
+         validation_labels: list[str] | list[list[str]] | None = None,
+     ) -> dict[str, float]:
+         """Train model on labeled items with participant identifiers.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels : list[str] | list[list[str]]
+             Training labels (format depends on task type).
+         participant_ids : list[str] | None
+             Participant identifier for each item.
+             - For fixed effects (mode='fixed'): Pass None (automatically handled).
+             - For mixed effects (mode='random_intercepts' or 'random_slopes'):
+               Must provide list[str] with same length as items.
+               Must not contain empty strings.
+         validation_items : list[Item] | None
+             Optional validation items.
+         validation_labels : list[str] | list[list[str]] | None
+             Optional validation labels.
+
+         Returns
+         -------
+         dict[str, float]
+             Training metrics including:
+             - "train_accuracy", "train_loss": Standard metrics
+             - "participant_variance": σ²_u (if estimate_variance_components=True)
+             - "n_participants": Number of unique participants
+             - "residual_variance": σ²_ε (if estimated)
+
+         Raises
+         ------
+         ValueError
+             If participant_ids is None when mode is 'random_intercepts'
+             or 'random_slopes'.
+         ValueError
+             If items, labels, and participant_ids have different lengths.
+         ValueError
+             If participant_ids contains empty strings.
+         ValueError
+             If validation data is incomplete.
+         ValueError
+             If labels are invalid for this task type.
+         """
+         # Validate input lengths (handle both list[str] and list[list[str]] labels)
+         if labels and isinstance(labels[0], list):
+             # Cloze model: labels is list[list[str]]
+             if len(items) != len(labels):
+                 raise ValueError(
+                     f"Number of items ({len(items)}) must match "
+                     f"number of labels ({len(labels)})"
+                 )
+         else:
+             # Standard models: labels is list[str]
+             self._validate_items_labels_length(items, labels)
+
+         # Validate and normalize participant_ids
+         participant_ids = self._normalize_participant_ids(
+             participant_ids, items, self.config.mixed_effects.mode
+         )
+
+         if (validation_items is None) != (validation_labels is None):
+             raise ValueError(
+                 "Both validation_items and validation_labels must be "
+                 "provided, or neither"
+             )
+
+         # Prepare training data (model-specific)
+         (
+             prepared_items,
+             labels_numeric,
+             participant_ids,
+             validation_items,
+             validation_labels_numeric,
+         ) = self._prepare_training_data(
+             items, labels, participant_ids, validation_items, validation_labels
+         )
+
+         # Initialize random effects
+         n_classes = self._get_n_classes_for_random_effects()
+         self._initialize_random_effects(n_classes)
+
+         # Register participants for adaptive regularization
+         if hasattr(self, "random_effects") and self.random_effects is not None:
+             participant_counts = Counter(participant_ids)
+             for pid, count in participant_counts.items():
+                 self.random_effects.register_participant(pid, count)
+
+         # Perform training (model-specific)
+         metrics = self._do_training(
+             prepared_items,
+             labels_numeric,
+             participant_ids,
+             validation_items,
+             validation_labels_numeric,
+         )
+
+         self._is_fitted = True
+
+         # Estimate variance components
+         if (
+             self.config.mixed_effects.estimate_variance_components
+             and hasattr(self, "random_effects")
+             and self.random_effects is not None
+         ):
+             var_comps = self.random_effects.estimate_variance_components()
+             if var_comps:
+                 var_comp = var_comps.get("mu") or var_comps.get("slopes")
+                 if var_comp:
+                     if not hasattr(self, "variance_history"):
+                         self.variance_history = []
+                     self.variance_history.append(var_comp)
+                     metrics["participant_variance"] = var_comp.variance
+                     metrics["n_participants"] = var_comp.n_groups
+
+         return metrics
+
+     def predict(
+         self, items: list[Item], participant_ids: list[str] | None = None
+     ) -> list[ModelPrediction]:
+         """Predict class labels for items with participant identifiers.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str] | None
+             Participant identifier for each item.
+             - For fixed effects (mode='fixed'): Pass None.
+             - For mixed effects: Must provide list[str] with same length as items.
+             - For unknown participants: Use population mean (prior) for random effects.
+
+         Returns
+         -------
+         list[ModelPrediction]
+             Predictions with probabilities and predicted class for each item.
+
+         Raises
+         ------
+         ValueError
+             If model has not been trained.
+         ValueError
+             If participant_ids is None when mode requires mixed effects.
+         ValueError
+             If items and participant_ids have different lengths.
+         ValueError
+             If participant_ids contains empty strings.
+         ValueError
+             If items are incompatible with model.
+         """
+         if not self._is_fitted:
+             raise ValueError("Model not trained. Call train() before predict().")
+
+         # Validate and normalize participant_ids
+         participant_ids = self._normalize_participant_ids(
+             participant_ids, items, self.config.mixed_effects.mode
+         )
+
+         return self._do_predict(items, participant_ids)
+
+     def predict_proba(
+         self, items: list[Item], participant_ids: list[str] | None = None
+     ) -> np.ndarray:
+         """Predict class probabilities for items with participant identifiers.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str] | None
+             Participant identifier for each item.
+             - For fixed effects (mode='fixed'): Pass None.
+             - For mixed effects: Must provide list[str] with same length as items.
+
+         Returns
+         -------
+         np.ndarray
+             Array of shape (n_items, n_classes) with probabilities.
+             Each row sums to 1.0 for classification tasks.
+
+         Raises
+         ------
+         ValueError
+             If model has not been trained.
+         ValueError
+             If participant_ids is None when mode requires mixed effects.
+         ValueError
+             If items and participant_ids have different lengths.
+         ValueError
+             If participant_ids contains empty strings.
+         ValueError
+             If items are incompatible with model.
+         """
+         if not self._is_fitted:
+             raise ValueError("Model not trained. Call train() before predict_proba().")
+
+         # Validate and normalize participant_ids
+         participant_ids = self._normalize_participant_ids(
+             participant_ids, items, self.config.mixed_effects.mode
+         )
+
+         return self._do_predict_proba(items, participant_ids)
+
+     def save(self, path: str) -> None:
+         """Save model to disk.
+
+         Parameters
+         ----------
+         path : str
+             File or directory path to save the model.
+
+         Raises
+         ------
+         ValueError
+             If model has not been trained.
+         """
+         if not self._is_fitted:
+             raise ValueError("Model not trained. Call train() before save().")
+
+         save_path = Path(path)
+         save_path.mkdir(parents=True, exist_ok=True)
+
+         # Save model-specific components
+         self._save_model_components(save_path)
+
+         # Save random effects (includes variance history)
+         if hasattr(self, "random_effects") and self.random_effects is not None:
+             # Copy variance_history from model to random_effects before saving
+             if hasattr(self, "variance_history"):
+                 self.random_effects.variance_history = self.variance_history.copy()
+             self.random_effects.save(save_path / "random_effects")
+
+         # Save config with model-specific state
+         config_dict = self.config.model_dump()
+         save_state = self._get_save_state()
+         config_dict.update(save_state)
+
+         with open(save_path / "config.json", "w") as f:
+             json.dump(config_dict, f, indent=2)
+
+     def load(self, path: str) -> None:
+         """Load model from disk.
+
+         Parameters
+         ----------
+         path : str
+             File or directory path to load the model from.
+
+         Raises
+         ------
+         FileNotFoundError
+             If model file/directory does not exist.
+         """
+         load_path = Path(path)
+         if not load_path.exists():
+             raise FileNotFoundError(f"Model directory not found: {path}")
+
+         with open(load_path / "config.json") as f:
+             config_dict = json.load(f)
+
+         # Restore model-specific training state (before reconstructing config)
+         self._restore_training_state(config_dict)
+
+         # Load model-specific components (which will reconstruct the config)
+         # This must happen before initializing random effects so config is correct
+         self._load_model_components(load_path)
+
+         # Initialize and load random effects
+         n_classes = self._get_n_classes_for_random_effects()
+         from bead.active_learning.models.random_effects import ( # noqa: PLC0415
+             RandomEffectsManager,
+         )
+
+         # Check if model uses vocab_size instead of n_classes (e.g., ClozeModel)
+         if hasattr(self, "tokenizer") and hasattr(self.tokenizer, "vocab_size"):
+             # ClozeModel: use vocab_size
+             self.random_effects = RandomEffectsManager(
+                 self.config.mixed_effects, vocab_size=n_classes
+             )
+         else:
+             # Standard models: use n_classes
+             self.random_effects = RandomEffectsManager(
+                 self.config.mixed_effects, n_classes=n_classes
+             )
+         random_effects_path = load_path / "random_effects"
+         if random_effects_path.exists():
+             fixed_head = self._get_random_effects_fixed_head()
+             self.random_effects.load(random_effects_path, fixed_head=fixed_head)
+             # Restore variance history from random_effects
+             if hasattr(self.random_effects, "variance_history"):
+                 if not hasattr(self, "variance_history"):
+                     self.variance_history = []
+                 self.variance_history = self.random_effects.variance_history.copy()
+
+         # Move to device (model-specific)
+         if hasattr(self, "encoder"):
+             self.encoder.to(self.config.device)
+         if hasattr(self, "model"):
+             self.model.to(self.config.device)
+
+         self._is_fitted = True
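
The module docstring above states the GLMM formulation y = Xβ + Zu + ε and the three modeling modes (fixed effects, random intercepts, random slopes). As a quick, self-contained illustration of that decomposition, here is a small NumPy sketch using per-participant random intercepts; it is not code from the bead package, and every name in it is local to this example.

import numpy as np

# Sketch of y = X @ beta + Z @ u + eps, as described in the module docstring above.
rng = np.random.default_rng(0)

participants = np.array([0, 0, 1, 1, 2, 2])  # grouping factor: three participants
X = rng.normal(size=(6, 2))                  # fixed-effects design matrix
beta = np.array([1.5, -0.5])                 # fixed effects, shared across all groups
u = rng.normal(scale=0.8, size=3)            # random intercepts, u ~ N(0, G)
Z = np.eye(3)[participants]                  # indicator matrix mapping rows to groups
eps = rng.normal(scale=0.1, size=6)          # residuals

y = X @ beta + Z @ u + eps

# "fixed" mode drops Z @ u entirely; "random_intercepts" adds a per-group bias as
# above; "random_slopes" would instead give each group its own beta (its own head).
print(y.round(3))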