bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/models/magnitude.py
@@ -0,0 +1,835 @@
+ """Magnitude model for unbounded and bounded numeric judgments.
+
+ Implements continuous regression with support for:
+
+ - Unbounded values: Normal distribution N(μ, σ²)
+ - Bounded values: Normal distribution N(μ, σ²) truncated to [min_value, max_value]
+
+ Supports GLMM-style estimation with participant-level random effects
+ (intercepts and slopes).
+ """
+
+ from __future__ import annotations
+
+ import json
+ import tempfile
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.distributions import Normal
+ from transformers import AutoModel, AutoTokenizer, TrainingArguments
+
+ from bead.active_learning.config import MixedEffectsConfig, VarianceComponents
+ from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
+ from bead.active_learning.models.random_effects import RandomEffectsManager
+ from bead.active_learning.trainers.data_collator import MixedEffectsDataCollator
+ from bead.active_learning.trainers.dataset_utils import items_to_dataset
+ from bead.active_learning.trainers.metrics import compute_regression_metrics
+ from bead.active_learning.trainers.model_wrapper import EncoderRegressionWrapper
+ from bead.config.active_learning import MagnitudeModelConfig
+ from bead.items.item import Item
+ from bead.items.item_template import ItemTemplate, TaskType
+
+ __all__ = ["MagnitudeModel"]
+
+
+ class MagnitudeModel(ActiveLearningModel):
+     """Model for magnitude tasks with unbounded or bounded continuous responses.
+
+     Uses a Normal distribution for unbounded values (e.g., reading time,
+     plausibility) or a Truncated Normal for bounded values (e.g., confidence
+     on a 0-100 scale). Supports three modes: fixed effects, random intercepts,
+     and random slopes.
+
+     Parameters
+     ----------
+     config : MagnitudeModelConfig
+         Configuration object containing all model parameters.
+
+     Attributes
+     ----------
+     config : MagnitudeModelConfig
+         Model configuration.
+     tokenizer : AutoTokenizer
+         Transformer tokenizer.
+     encoder : AutoModel
+         Transformer encoder model.
+     regression_head : nn.Sequential
+         Regression (fixed effects) head; outputs the continuous location μ.
+     random_effects : RandomEffectsManager
+         Manager for participant-level random effects.
+     variance_history : list[VarianceComponents]
+         Variance component estimates over training (for diagnostics).
+     _is_fitted : bool
+         Whether the model has been trained.
+
+     Examples
+     --------
+     >>> from uuid import uuid4
+     >>> from bead.items.item import Item
+     >>> from bead.config.active_learning import MagnitudeModelConfig
+     >>> items = [
+     ...     Item(
+     ...         item_template_id=uuid4(),
+     ...         rendered_elements={"text": f"Sentence {i}"}
+     ...     )
+     ...     for i in range(10)
+     ... ]
+     >>> labels = ["250.5", "300.2"] * 5  # Reading times in ms
+     >>> config = MagnitudeModelConfig(  # doctest: +SKIP
+     ...     num_epochs=1, batch_size=2, device="cpu"
+     ... )
+     >>> model = MagnitudeModel(config=config)  # doctest: +SKIP
+     >>> metrics = model.train(items, labels, participant_ids=None)  # doctest: +SKIP
+     >>> predictions = model.predict(items[:3], participant_ids=None)  # doctest: +SKIP
+     """
+
+     def __init__(
+         self,
+         config: MagnitudeModelConfig | None = None,
+     ) -> None:
+         """Initialize magnitude model.
+
+         Parameters
+         ----------
+         config : MagnitudeModelConfig | None
+             Configuration object. If None, uses default configuration.
+         """
+         self.config = config or MagnitudeModelConfig()
+
+         # Validate mixed_effects configuration
+         super().__init__(self.config)
+
+         self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+         self.encoder = AutoModel.from_pretrained(self.config.model_name)
+
+         self.regression_head: nn.Sequential | None = None
+         self._is_fitted = False
+
+         # Initialize random effects manager
+         self.random_effects: RandomEffectsManager | None = None
+         self.variance_history: list[VarianceComponents] = []
+
+         self.encoder.to(self.config.device)
+
+     @property
+     def supported_task_types(self) -> list[TaskType]:
+         """Get supported task types.
+
+         Returns
+         -------
+         list[TaskType]
+             List containing "magnitude".
+         """
+         return ["magnitude"]
+
+     def validate_item_compatibility(
+         self, item: Item, item_template: ItemTemplate
+     ) -> None:
+         """Validate item is compatible with magnitude model.
+
+         Parameters
+         ----------
+         item : Item
+             Item to validate.
+         item_template : ItemTemplate
+             Template the item was constructed from.
+
+         Raises
+         ------
+         ValueError
+             If task_type is not "magnitude".
+         """
+         if item_template.task_type != "magnitude":
+             raise ValueError(
+                 f"Expected task_type 'magnitude', got '{item_template.task_type}'"
+             )
+
+     def _initialize_regression_head(self) -> None:
+         """Initialize regression head for continuous output μ."""
+         hidden_size = self.encoder.config.hidden_size
+
+         # Single output for location parameter μ
+         self.regression_head = nn.Sequential(
+             nn.Linear(hidden_size, 256),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(256, 1),  # Output μ (location parameter)
+         )
+         self.regression_head.to(self.config.device)
+
+     def _encode_texts(self, texts: list[str]) -> torch.Tensor:
+         """Encode texts using transformer.
+
+         Parameters
+         ----------
+         texts : list[str]
+             Texts to encode.
+
+         Returns
+         -------
+         torch.Tensor
+             Encoded representations of shape (batch_size, hidden_size).
+         """
+         encodings = self.tokenizer(
+             texts,
+             padding=True,
+             truncation=True,
+             max_length=self.config.max_length,
+             return_tensors="pt",
+         )
+         encodings = {k: v.to(self.config.device) for k, v in encodings.items()}
+
+         outputs = self.encoder(**encodings)
+         return outputs.last_hidden_state[:, 0, :]
+
+     def _prepare_inputs(self, items: list[Item]) -> torch.Tensor:
+         """Prepare inputs for encoding.
+
+         For magnitude tasks, concatenates all rendered elements.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to encode.
+
+         Returns
+         -------
+         torch.Tensor
+             Encoded representations.
+         """
+         texts = []
+         for item in items:
+             # Concatenate all rendered elements
+             all_text = " ".join(item.rendered_elements.values())
+             texts.append(all_text)
+         return self._encode_texts(texts)
+
+     def _normal_log_prob(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
+         """Compute log probability of normal distribution (unbounded case).
+
+         Parameters
+         ----------
+         y : torch.Tensor
+             Observed values, shape (batch,).
+         mu : torch.Tensor
+             Location parameters, shape (batch,).
+
+         Returns
+         -------
+         torch.Tensor
+             Log probabilities, shape (batch,).
+         """
+         dist = Normal(mu.squeeze(), self.config.sigma)
+         return dist.log_prob(y)
+
+     def _truncated_normal_log_prob(
+         self, y: torch.Tensor, mu: torch.Tensor
+     ) -> torch.Tensor:
+         """Compute log probability of truncated normal distribution (bounded case).
+
+         Uses a normal truncated to [min_value, max_value] to properly handle
+         bounded responses without arbitrary nudging.
+
+         Parameters
+         ----------
+         y : torch.Tensor
+             Observed values, shape (batch,).
+         mu : torch.Tensor
+             Location parameters (before truncation), shape (batch,).
+
+         Returns
+         -------
+         torch.Tensor
+             Log probabilities, shape (batch,).
+         """
+         base_dist = Normal(mu.squeeze(), self.config.sigma)
+
+         # Unnormalized log prob
+         log_prob_unnorm = base_dist.log_prob(y)
+
+         # Normalizer: log(Φ(β) - Φ(α)) with α = (low - μ)/σ, β = (high - μ)/σ.
+         # Φ is the *standard* normal CDF, so evaluate a unit normal at the
+         # standardized bounds (base_dist.cdf would standardize them a second time).
+         std_normal = Normal(torch.zeros_like(mu.squeeze()), 1.0)
+         alpha = (self.config.min_value - mu.squeeze()) / self.config.sigma
+         beta = (self.config.max_value - mu.squeeze()) / self.config.sigma
+         normalizer = std_normal.cdf(beta) - std_normal.cdf(alpha)
+
+         # Clamp to avoid log(0)
+         normalizer = torch.clamp(normalizer, min=1e-8)
+         log_normalizer = torch.log(normalizer)
+
+         return log_prob_unnorm - log_normalizer
+
+     def _prepare_training_data(
+         self,
+         items: list[Item],
+         labels: list[str],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> tuple[
+         list[Item], list[float], list[str], list[Item] | None, list[float] | None
+     ]:
+         """Prepare training data for magnitude model.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels : list[str]
+             Training labels (continuous values as strings).
+         participant_ids : list[str]
+             Normalized participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+
+         Returns
+         -------
+         tuple[list[Item], list[float], list[str], list[Item] | None, list[float] | None]
+             Prepared items, numeric labels (floats), participant_ids,
+             validation_items, numeric validation_labels.
+         """
+         # Parse labels to floats
+         try:
+             y_values = [float(label) for label in labels]
+         except ValueError as e:
+             raise ValueError(
+                 f"Labels must be numeric strings (e.g., '250.5', '300.2'). "
+                 f"Got error: {e}"
+             ) from e
+
+         # Validate bounds for bounded case
+         if self.config.bounded:
+             for i, val in enumerate(y_values):
+                 if not (self.config.min_value <= val <= self.config.max_value):
+                     raise ValueError(
+                         f"Label at index {i} ({val}) is outside bounds "
+                         f"[{self.config.min_value}, {self.config.max_value}]"
+                     )
+
+         self._initialize_regression_head()
+
+         # Convert validation labels if provided
+         val_y_numeric = None
+         if validation_items is not None and validation_labels is not None:
+             try:
+                 val_y_numeric = [float(label) for label in validation_labels]
+             except ValueError as e:
+                 raise ValueError(
+                     f"Validation labels must be numeric strings. Got error: {e}"
+                 ) from e
+
+             # Validate bounds for validation labels
+             if self.config.bounded:
+                 for i, val in enumerate(val_y_numeric):
+                     if not (self.config.min_value <= val <= self.config.max_value):
+                         raise ValueError(
+                             f"Validation label at index {i} ({val}) is outside bounds "
+                             f"[{self.config.min_value}, {self.config.max_value}]"
+                         )
+
+         return items, y_values, participant_ids, validation_items, val_y_numeric
+
+     def _initialize_random_effects(self, n_classes: int) -> None:
+         """Initialize random effects manager.
+
+         Parameters
+         ----------
+         n_classes : int
+             Number of classes (1 for regression).
+         """
+         self.random_effects = RandomEffectsManager(
+             self.config.mixed_effects,
+             n_classes=n_classes,  # Scalar bias for μ
+         )
+
+     def _do_training(
+         self,
+         items: list[Item],
+         labels_numeric: list[float],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels_numeric: list[float] | None,
+     ) -> dict[str, float]:
+         """Perform magnitude model training.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels_numeric : list[float]
+             Numeric labels (continuous values).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels_numeric : list[float] | None
+             Numeric validation labels.
+
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         # Convert validation_labels_numeric back to string labels for validation metrics
+         validation_labels = None
+         if validation_items is not None and validation_labels_numeric is not None:
+             validation_labels = [str(val) for val in validation_labels_numeric]
+
+         # Use HuggingFace Trainer for fixed and random_intercepts modes;
+         # random_slopes requires a custom loop due to per-participant heads
+         use_huggingface_trainer = self.config.mixed_effects.mode in (
+             "fixed",
+             "random_intercepts",
+         )
+
+         if use_huggingface_trainer:
+             metrics = self._train_with_huggingface_trainer(
+                 items,
+                 labels_numeric,
+                 participant_ids,
+                 validation_items,
+                 validation_labels,
+             )
+         else:
+             # Use custom training loop for random_slopes
+             metrics = self._train_with_custom_loop(
+                 items,
+                 labels_numeric,
+                 participant_ids,
+                 validation_items,
+                 validation_labels,
+             )
+
+         # Add validation MSE if validation data provided and not already computed
+         if (
+             validation_items is not None
+             and validation_labels is not None
+             and "val_mse" not in metrics
+         ):
+             # Validation with placeholder participant_ids for mixed effects
+             if self.config.mixed_effects.mode == "fixed":
+                 val_participant_ids = ["_fixed_"] * len(validation_items)
+             else:
+                 val_participant_ids = ["_validation_"] * len(validation_items)
+             val_predictions = self._do_predict(validation_items, val_participant_ids)
+             val_pred_values = [float(p.predicted_class) for p in val_predictions]
+             val_true_values = [float(label) for label in validation_labels]
+             val_mse = np.mean(
+                 [
+                     (pred - true) ** 2
+                     for pred, true in zip(val_pred_values, val_true_values, strict=True)
+                 ]
+             )
+             metrics["val_mse"] = float(val_mse)
+
+         return metrics
+
+     def _train_with_huggingface_trainer(
+         self,
+         items: list[Item],
+         y_numeric: list[float],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> dict[str, float]:
+         """Train using HuggingFace Trainer with mixed effects support for regression.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         y_numeric : list[float]
+             Numeric labels (continuous values).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         # Convert items to HuggingFace Dataset
+         train_dataset = items_to_dataset(
+             items=items,
+             labels=y_numeric,
+             participant_ids=participant_ids,
+             tokenizer=self.tokenizer,
+             max_length=self.config.max_length,
+         )
+
+         eval_dataset = None
+         if validation_items is not None and validation_labels is not None:
+             val_y_numeric = [float(label) for label in validation_labels]
+             val_participant_ids = (
+                 ["_validation_"] * len(validation_items)
+                 if self.config.mixed_effects.mode != "fixed"
+                 else ["_fixed_"] * len(validation_items)
+             )
+             eval_dataset = items_to_dataset(
+                 items=validation_items,
+                 labels=val_y_numeric,
+                 participant_ids=val_participant_ids,
+                 tokenizer=self.tokenizer,
+                 max_length=self.config.max_length,
+             )
+
+         # Wrap the encoder and regression head for Trainer
+         wrapped_model = EncoderRegressionWrapper(
+             encoder=self.encoder, regression_head=self.regression_head
+         )
+
+         # Create data collator
+         data_collator = MixedEffectsDataCollator(tokenizer=self.tokenizer)
+
+         # Create training arguments with checkpointing
+         with tempfile.TemporaryDirectory() as tmpdir:
+             checkpoint_dir = Path(tmpdir) / "checkpoints"
+             checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+             training_args = TrainingArguments(
+                 output_dir=str(checkpoint_dir),
+                 num_train_epochs=self.config.num_epochs,
+                 per_device_train_batch_size=self.config.batch_size,
+                 per_device_eval_batch_size=self.config.batch_size,
+                 learning_rate=self.config.learning_rate,
+                 logging_steps=10,
+                 eval_strategy="epoch" if eval_dataset is not None else "no",
+                 save_strategy="epoch",
+                 save_total_limit=1,
+                 load_best_model_at_end=False,
+                 report_to="none",
+                 remove_unused_columns=False,
+                 use_cpu=self.config.device == "cpu",
+             )
+
+             # Import here to avoid circular import
+             from bead.active_learning.trainers.mixed_effects import (  # noqa: PLC0415
+                 MixedEffectsTrainer,
+             )
+
+             # Create trainer
+             trainer = MixedEffectsTrainer(
+                 model=wrapped_model,
+                 args=training_args,
+                 train_dataset=train_dataset,
+                 eval_dataset=eval_dataset,
+                 data_collator=data_collator,
+                 tokenizer=self.tokenizer,
+                 random_effects_manager=self.random_effects,
+                 compute_metrics=compute_regression_metrics,
+             )
+
+             # Train
+             train_result = trainer.train()
+
+             # Get training metrics
+             train_metrics = trainer.evaluate(eval_dataset=train_dataset)
+             metrics: dict[str, float] = {
+                 "train_loss": float(train_result.training_loss),
+                 "train_mse": train_metrics.get("eval_mse", 0.0),
+                 "train_mae": train_metrics.get("eval_mae", 0.0),
+                 "train_r2": train_metrics.get("eval_r2", 0.0),
+             }
+
+             # Get validation metrics if eval_dataset was provided
+             if eval_dataset is not None:
+                 val_metrics = trainer.evaluate(eval_dataset=eval_dataset)
+                 metrics.update(
+                     {
+                         "val_mse": val_metrics.get("eval_mse", 0.0),
+                         "val_mae": val_metrics.get("eval_mae", 0.0),
+                         "val_r2": val_metrics.get("eval_r2", 0.0),
+                     }
+                 )
+
+         return metrics
+
+     def _train_with_custom_loop(
+         self,
+         items: list[Item],
+         y_numeric: list[float],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> dict[str, float]:
+         """Train using custom loop for random_slopes mode.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         y_numeric : list[float]
+             Numeric labels (continuous values).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         y = torch.tensor(y_numeric, dtype=torch.float, device=self.config.device)
+
+         # Build optimizer parameters
+         params_to_optimize = list(self.encoder.parameters()) + list(
+             self.regression_head.parameters()
+         )
+
+         # Add random effects parameters for random_slopes
+         for head in self.random_effects.slopes.values():
+             params_to_optimize.extend(head.parameters())
+
+         optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
+
+         self.encoder.train()
+         self.regression_head.train()
+
+         for _epoch in range(self.config.num_epochs):
+             n_batches = (
+                 len(items) + self.config.batch_size - 1
+             ) // self.config.batch_size
+             epoch_loss = 0.0
+             epoch_mse = 0.0
+
+             for i in range(n_batches):
+                 start_idx = i * self.config.batch_size
+                 end_idx = min(start_idx + self.config.batch_size, len(items))
+
+                 batch_items = items[start_idx:end_idx]
+                 batch_labels = y[start_idx:end_idx]
+                 batch_participant_ids = participant_ids[start_idx:end_idx]
+
+                 embeddings = self._prepare_inputs(batch_items)
+
+                 # Per-participant head for random_slopes
+                 mu_list = []
+                 for j, pid in enumerate(batch_participant_ids):
+                     participant_head = self.random_effects.get_slopes(
+                         pid,
+                         fixed_head=self.regression_head,
+                         create_if_missing=True,
+                     )
+                     mu_j = participant_head(embeddings[j : j + 1]).squeeze()
+                     mu_list.append(mu_j)
+                 mu = torch.stack(mu_list)
+
+                 # Compute negative log-likelihood
+                 if self.config.bounded:
+                     log_probs = self._truncated_normal_log_prob(batch_labels, mu)
+                 else:
+                     log_probs = self._normal_log_prob(batch_labels, mu)
+                 loss_nll = -log_probs.mean()
+
+                 # Add prior regularization
+                 loss_prior = self.random_effects.compute_prior_loss()
+                 loss = loss_nll + loss_prior
+
+                 optimizer.zero_grad()
+                 loss.backward()
+                 optimizer.step()
+
+                 epoch_loss += loss.item()
+                 # Also track MSE for interpretability
+                 mse = ((mu - batch_labels) ** 2).mean().item()
+                 epoch_mse += mse
+
+             epoch_loss = epoch_loss / n_batches
+             epoch_mse = epoch_mse / n_batches
+
+         metrics: dict[str, float] = {
+             "train_loss": epoch_loss,
+             "train_mse": epoch_mse,
+         }
+
+         return metrics
+
+     def _do_predict(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> list[ModelPrediction]:
+         """Perform magnitude model prediction.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str]
+             Normalized participant IDs.
+
+         Returns
+         -------
+         list[ModelPrediction]
+             Predictions with predicted_class as string representation of value.
+         """
+         self.encoder.eval()
+         self.regression_head.eval()
+
+         with torch.no_grad():
+             embeddings = self._prepare_inputs(items)
+
+             # Forward pass depends on mixed effects mode
+             if self.config.mixed_effects.mode == "fixed":
+                 mu = self.regression_head(embeddings).squeeze(1)
+
+             elif self.config.mixed_effects.mode == "random_intercepts":
+                 mu = self.regression_head(embeddings).squeeze(1)
+                 for i, pid in enumerate(participant_ids):
+                     # Unknown participants: use prior mean (zero bias)
+                     bias = self.random_effects.get_intercepts(
+                         pid, n_classes=1, param_name="mu", create_if_missing=False
+                     )
+                     mu[i] = mu[i] + bias.item()
+
+             elif self.config.mixed_effects.mode == "random_slopes":
+                 mu_list = []
+                 for i, pid in enumerate(participant_ids):
+                     # Unknown participants: use fixed head
+                     participant_head = self.random_effects.get_slopes(
+                         pid, fixed_head=self.regression_head, create_if_missing=False
+                     )
+                     mu_i = participant_head(embeddings[i : i + 1]).squeeze()
+                     mu_list.append(mu_i)
+                 mu = torch.stack(mu_list)
+
+             # Clamp predictions to bounds if bounded
+             if self.config.bounded:
+                 mu = torch.clamp(mu, self.config.min_value, self.config.max_value)
+
+             predictions_array = mu.cpu().numpy()
+
+         predictions = []
+         for i, item in enumerate(items):
+             pred_value = float(predictions_array[i])
+             predictions.append(
+                 ModelPrediction(
+                     item_id=str(item.id),
+                     probabilities={},  # Not applicable for regression
+                     predicted_class=str(pred_value),  # Continuous value as string
+                     confidence=1.0,  # Not applicable for regression
+                 )
+             )
+
+         return predictions
+
+     def _do_predict_proba(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> np.ndarray:
+         """Perform magnitude model probability prediction.
+
+         For magnitude regression, returns μ values directly.
+
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str]
+             Normalized participant IDs.
+
+         Returns
+         -------
+         np.ndarray
+             Array of shape (n_items, 1) with predicted μ values.
+         """
+         predictions = self._do_predict(items, participant_ids)
+         return np.array([[float(p.predicted_class)] for p in predictions])
+
+     def _save_model_components(self, save_path: Path) -> None:
+         """Save model-specific components.
+
+         Parameters
+         ----------
+         save_path : Path
+             Directory to save to.
+         """
+         self.encoder.save_pretrained(save_path / "encoder")
+         self.tokenizer.save_pretrained(save_path / "encoder")
+
+         torch.save(
+             self.regression_head.state_dict(),
+             save_path / "regression_head.pt",
+         )
+
+     def _get_save_state(self) -> dict[str, object]:
+         """Get model-specific state to save.
+
+         Returns
+         -------
+         dict[str, object]
+             State dictionary.
+         """
+         return {}
+
+     def _restore_training_state(self, config_dict: dict[str, object]) -> None:
+         """Restore model-specific training state.
+
+         Parameters
+         ----------
+         config_dict : dict[str, object]
+             Configuration dictionary with training state.
+         """
+         # MagnitudeModel doesn't have additional training state to restore
+         pass
+
+     def _load_model_components(self, load_path: Path) -> None:
+         """Load model-specific components.
+
+         Parameters
+         ----------
+         load_path : Path
+             Directory to load from.
+         """
+         # Load config.json to reconstruct config
+         with open(load_path / "config.json") as f:
+             config_dict = json.load(f)
+
+         # Reconstruct MixedEffectsConfig if needed
+         if "mixed_effects" in config_dict and isinstance(
+             config_dict["mixed_effects"], dict
+         ):
+             config_dict["mixed_effects"] = MixedEffectsConfig(
+                 **config_dict["mixed_effects"]
+             )
+
+         self.config = MagnitudeModelConfig(**config_dict)
+
+         self.encoder = AutoModel.from_pretrained(load_path / "encoder")
+         self.tokenizer = AutoTokenizer.from_pretrained(load_path / "encoder")
+
+         self._initialize_regression_head()
+         self.regression_head.load_state_dict(
+             torch.load(
+                 load_path / "regression_head.pt", map_location=self.config.device
+             )
+         )
+
+         self.encoder.to(self.config.device)
+         self.regression_head.to(self.config.device)
+
+     def _get_n_classes_for_random_effects(self) -> int:
+         """Get the number of classes for initializing RandomEffectsManager.
+
+         For magnitude models, this is 1 (scalar bias).
+
+         Returns
+         -------
+         int
+             Always 1 for regression.
+         """
+         return 1
+
+     def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
+         """Get the fixed head for random effects.
+
+         Returns
+         -------
+         torch.nn.Module | None
+             The regression head, or None if not applicable.
+         """
+         return self.regression_head
+ return self.regression_head