bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/models/ordinal_scale.py
@@ -0,0 +1,811 @@
+ """Ordinal scale model for ordered rating scales (Likert, sliders, etc.).
+ 
+ Implements a truncated normal distribution for bounded continuous responses
+ on [scale_min, scale_max] (by default [0, 1]).
+ Supports GLMMs with participant-level random effects (intercepts and slopes).
+ """
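+ 
+ # Model sketch (informal): for participant p and item i with encoder features
+ # x_i, responses are modeled as
+ #     y_pi ~ TruncatedNormal(mu_pi, sigma) on [scale_min, scale_max]
+ # with the location mu_pi determined by the mixed-effects mode:
+ #     "fixed":             mu_pi = f(x_i)         (shared regression head)
+ #     "random_intercepts": mu_pi = f(x_i) + b_p   (per-participant bias)
+ #     "random_slopes":     mu_pi = f_p(x_i)       (per-participant head)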
+ 
+ from __future__ import annotations
+ 
+ import json
+ import tempfile
+ from pathlib import Path
+ 
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.distributions import Normal
+ from transformers import AutoModel, AutoTokenizer, TrainingArguments
+ 
+ from bead.active_learning.config import MixedEffectsConfig, VarianceComponents
+ from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
+ from bead.active_learning.models.random_effects import RandomEffectsManager
+ from bead.active_learning.trainers.data_collator import MixedEffectsDataCollator
+ from bead.active_learning.trainers.dataset_utils import items_to_dataset
+ from bead.active_learning.trainers.metrics import compute_regression_metrics
+ from bead.active_learning.trainers.model_wrapper import EncoderRegressionWrapper
+ from bead.config.active_learning import OrdinalScaleModelConfig
+ from bead.items.item import Item
+ from bead.items.item_template import ItemTemplate, TaskType
+ 
+ __all__ = ["OrdinalScaleModel"]
+ 
+ 
+ class OrdinalScaleModel(ActiveLearningModel):
+     """Model for ordinal_scale tasks with bounded continuous responses.
+ 
+     Uses a truncated normal distribution on [scale_min, scale_max] to model
+     slider/Likert responses while properly handling responses at the scale
+     endpoints. Supports three modes: fixed effects, random intercepts, and
+     random slopes.
+ 
+     Parameters
+     ----------
+     config : OrdinalScaleModelConfig
+         Configuration object containing all model parameters.
+ 
+     Attributes
+     ----------
+     config : OrdinalScaleModelConfig
+         Model configuration.
+     tokenizer : AutoTokenizer
+         Transformer tokenizer.
+     encoder : AutoModel
+         Transformer encoder model.
+     regression_head : nn.Sequential
+         Regression head (fixed-effects head); outputs the continuous location μ.
+     random_effects : RandomEffectsManager
+         Manager for participant-level random effects.
+     variance_history : list[VarianceComponents]
+         Variance component estimates over training (for diagnostics).
+     _is_fitted : bool
+         Whether model has been trained.
+ 
+     Examples
+     --------
+     >>> from uuid import uuid4
+     >>> from bead.items.item import Item
+     >>> from bead.config.active_learning import OrdinalScaleModelConfig
+     >>> items = [
+     ...     Item(
+     ...         item_template_id=uuid4(),
+     ...         rendered_elements={"text": f"Sentence {i}"}
+     ...     )
+     ...     for i in range(10)
+     ... ]
+     >>> labels = ["0.3", "0.7"] * 5  # Continuous values as strings
+     >>> config = OrdinalScaleModelConfig(  # doctest: +SKIP
+     ...     num_epochs=1, batch_size=2, device="cpu"
+     ... )
+     >>> model = OrdinalScaleModel(config=config)  # doctest: +SKIP
+     >>> metrics = model.train(items, labels, participant_ids=None)  # doctest: +SKIP
+     >>> predictions = model.predict(items[:3], participant_ids=None)  # doctest: +SKIP
+     """
+ 
+     def __init__(
+         self,
+         config: OrdinalScaleModelConfig | None = None,
+     ) -> None:
+         """Initialize ordinal scale model.
+ 
+         Parameters
+         ----------
+         config : OrdinalScaleModelConfig | None
+             Configuration object. If None, uses default configuration.
+         """
+         self.config = config or OrdinalScaleModelConfig()
+ 
+         # Validate mixed_effects configuration
+         super().__init__(self.config)
+ 
+         self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+         self.encoder = AutoModel.from_pretrained(self.config.model_name)
+ 
+         self.regression_head: nn.Sequential | None = None
+         self._is_fitted = False
+ 
+         # Initialize random effects manager
+         self.random_effects: RandomEffectsManager | None = None
+         self.variance_history: list[VarianceComponents] = []
+ 
+         self.encoder.to(self.config.device)
+ 
+     @property
+     def supported_task_types(self) -> list[TaskType]:
+         """Get supported task types.
+ 
+         Returns
+         -------
+         list[TaskType]
+             List containing "ordinal_scale".
+         """
+         return ["ordinal_scale"]
+ 
+     def validate_item_compatibility(
+         self, item: Item, item_template: ItemTemplate
+     ) -> None:
+         """Validate item is compatible with ordinal scale model.
+ 
+         Parameters
+         ----------
+         item : Item
+             Item to validate.
+         item_template : ItemTemplate
+             Template the item was constructed from.
+ 
+         Raises
+         ------
+         ValueError
+             If task_type is not "ordinal_scale".
+         """
+         if item_template.task_type != "ordinal_scale":
+             raise ValueError(
+                 f"Expected task_type 'ordinal_scale', got '{item_template.task_type}'"
+             )
+ 
+     def _initialize_regression_head(self) -> None:
+         """Initialize regression head for continuous output μ."""
+         hidden_size = self.encoder.config.hidden_size
+ 
+         # Single output for location parameter μ
+         self.regression_head = nn.Sequential(
+             nn.Linear(hidden_size, 256),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(256, 1),  # Output μ (location parameter)
+         )
+         self.regression_head.to(self.config.device)
+ 
+     def _encode_texts(self, texts: list[str]) -> torch.Tensor:
+         """Encode texts using transformer.
+ 
+         Parameters
+         ----------
+         texts : list[str]
+             Texts to encode.
+ 
+         Returns
+         -------
+         torch.Tensor
+             Encoded representations of shape (batch_size, hidden_size).
+         """
+         encodings = self.tokenizer(
+             texts,
+             padding=True,
+             truncation=True,
+             max_length=self.config.max_length,
+             return_tensors="pt",
+         )
+         encodings = {k: v.to(self.config.device) for k, v in encodings.items()}
+ 
+         outputs = self.encoder(**encodings)
+         return outputs.last_hidden_state[:, 0, :]
+ 
+     def _prepare_inputs(self, items: list[Item]) -> torch.Tensor:
+         """Prepare inputs for encoding.
+ 
+         For ordinal scale tasks, concatenates all rendered elements.
+ 
+         Parameters
+         ----------
+         items : list[Item]
+             Items to encode.
+ 
+         Returns
+         -------
+         torch.Tensor
+             Encoded representations.
+         """
+         texts = []
+         for item in items:
+             # Concatenate all rendered elements
+             all_text = " ".join(item.rendered_elements.values())
+             texts.append(all_text)
+         return self._encode_texts(texts)
+ 
+     def _truncated_normal_log_prob(
+         self, y: torch.Tensor, mu: torch.Tensor, sigma: float
+     ) -> torch.Tensor:
+         """Compute log probability of truncated normal distribution.
+ 
+         Uses truncated normal on [scale_min, scale_max] to properly handle
+         endpoint responses (0.0 and 1.0) without arbitrary nudging.
+ 
+         Parameters
+         ----------
+         y : torch.Tensor
+             Observed values, shape (batch,).
+         mu : torch.Tensor
+             Location parameters (before truncation), shape (batch,).
+         sigma : float
+             Scale parameter (standard deviation).
+ 
+         Returns
+         -------
+         torch.Tensor
+             Log probabilities, shape (batch,).
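+ 
+         Examples
+         --------
+         A minimal sketch (assumes a default-configured model with scale bounds
+         [0, 1]; skipped under doctest because it loads a full encoder):
+ 
+         >>> import torch
+         >>> model = OrdinalScaleModel()  # doctest: +SKIP
+         >>> y = torch.tensor([0.0, 0.5, 1.0])
+         >>> mu = torch.tensor([0.2, 0.5, 0.8])
+         >>> model._truncated_normal_log_prob(y, mu, sigma=0.25)  # doctest: +SKIP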
+         """
+         base_dist = Normal(mu.squeeze(), sigma)
+ 
+         # Unnormalized log prob
+         log_prob_unnorm = base_dist.log_prob(y)
+ 
+         # Normalizer: log(Φ((high-μ)/σ) - Φ((low-μ)/σ)), where Φ is the
+         # *standard* normal CDF, so evaluate with Normal(0, 1) rather than
+         # base_dist (whose CDF is already shifted and scaled by μ and σ)
+         alpha = (self.config.scale.min - mu.squeeze()) / sigma
+         beta = (self.config.scale.max - mu.squeeze()) / sigma
+         std_normal = Normal(torch.zeros_like(alpha), torch.ones_like(alpha))
+         normalizer = std_normal.cdf(beta) - std_normal.cdf(alpha)
+ 
+         # Clamp to avoid log(0)
+         normalizer = torch.clamp(normalizer, min=1e-8)
+         log_normalizer = torch.log(normalizer)
+ 
+         return log_prob_unnorm - log_normalizer
+ 
+     def _prepare_training_data(
+         self,
+         items: list[Item],
+         labels: list[str],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> tuple[
+         list[Item], list[float], list[str], list[Item] | None, list[float] | None
+     ]:
+         """Prepare training data for ordinal scale model.
+ 
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels : list[str]
+             Training labels (continuous values as strings).
+         participant_ids : list[str]
+             Normalized participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+ 
+         Returns
+         -------
+         tuple[list[Item], list[float], list[str], list[Item] | None, list[float] | None]
+             Prepared items, numeric labels (floats), participant_ids,
+             validation_items, numeric validation_labels.
+         """
+         # Parse labels to floats and validate bounds
+         try:
+             y_values = [float(label) for label in labels]
+         except ValueError as e:
+             raise ValueError(
+                 f"Labels must be numeric strings (e.g., '0.5', '0.75'). Got error: {e}"
+             ) from e
+ 
+         # Validate all values are within bounds
+         for i, val in enumerate(y_values):
+             if not (self.config.scale.min <= val <= self.config.scale.max):
+                 raise ValueError(
+                     f"Label at index {i} ({val}) is outside bounds "
+                     f"[{self.config.scale.min}, {self.config.scale.max}]"
+                 )
+ 
+         self._initialize_regression_head()
+ 
+         # Convert validation labels if provided
+         val_y_numeric = None
+         if validation_items is not None and validation_labels is not None:
+             try:
+                 val_y_numeric = [float(label) for label in validation_labels]
+             except ValueError as e:
+                 raise ValueError(
+                     f"Validation labels must be numeric strings. Got error: {e}"
+                 ) from e
+ 
+             # Validate bounds for validation labels
+             for i, val in enumerate(val_y_numeric):
+                 if not (self.config.scale.min <= val <= self.config.scale.max):
+                     raise ValueError(
+                         f"Validation label at index {i} ({val}) is outside bounds "
+                         f"[{self.config.scale.min}, {self.config.scale.max}]"
+                     )
+ 
+         return items, y_values, participant_ids, validation_items, val_y_numeric
+ 
+     def _initialize_random_effects(self, n_classes: int) -> None:
+         """Initialize random effects manager.
+ 
+         Parameters
+         ----------
+         n_classes : int
+             Number of classes (1 for regression).
+         """
+         self.random_effects = RandomEffectsManager(
+             self.config.mixed_effects,
+             n_classes=n_classes,  # Scalar bias for μ
+         )
+ 
+     def _do_training(
+         self,
+         items: list[Item],
+         labels_numeric: list[float],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels_numeric: list[float] | None,
+     ) -> dict[str, float]:
+         """Perform ordinal scale model training.
+ 
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         labels_numeric : list[float]
+             Numeric labels (continuous values).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels_numeric : list[float] | None
+             Numeric validation labels.
+ 
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         # Convert validation_labels_numeric back to string labels for validation metrics
+         validation_labels = None
+         if validation_items is not None and validation_labels_numeric is not None:
+             validation_labels = [str(val) for val in validation_labels_numeric]
+ 
+         # Use HuggingFace Trainer for fixed and random_intercepts modes;
+         # random_slopes requires a custom loop due to per-participant heads
+         use_huggingface_trainer = self.config.mixed_effects.mode in (
+             "fixed",
+             "random_intercepts",
+         )
+ 
+         if use_huggingface_trainer:
+             metrics = self._train_with_huggingface_trainer(
+                 items,
+                 labels_numeric,
+                 participant_ids,
+                 validation_items,
+                 validation_labels,
+             )
+         else:
+             # Use custom training loop for random_slopes
+             metrics = self._train_with_custom_loop(
+                 items,
+                 labels_numeric,
+                 participant_ids,
+                 validation_items,
+                 validation_labels,
+             )
+ 
+         # Add validation MSE if validation data provided and not already computed
+         if (
+             validation_items is not None
+             and validation_labels is not None
+             and "val_mse" not in metrics
+         ):
+             # Validation with placeholder participant_ids for mixed effects
+             if self.config.mixed_effects.mode == "fixed":
+                 val_participant_ids = ["_fixed_"] * len(validation_items)
+             else:
+                 val_participant_ids = ["_validation_"] * len(validation_items)
+             val_predictions = self._do_predict(validation_items, val_participant_ids)
+             val_pred_values = [float(p.predicted_class) for p in val_predictions]
+             val_true_values = [float(label) for label in validation_labels]
+             val_mse = np.mean(
+                 [
+                     (pred - true) ** 2
+                     for pred, true in zip(val_pred_values, val_true_values, strict=True)
+                 ]
+             )
+             metrics["val_mse"] = float(val_mse)
+ 
+         return metrics
+ 
+     def _train_with_huggingface_trainer(
+         self,
+         items: list[Item],
+         y_numeric: list[float],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> dict[str, float]:
+         """Train using HuggingFace Trainer with mixed effects support for regression.
+ 
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         y_numeric : list[float]
+             Numeric labels (continuous values).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+ 
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         # Convert items to HuggingFace Dataset
+         train_dataset = items_to_dataset(
+             items=items,
+             labels=y_numeric,
+             participant_ids=participant_ids,
+             tokenizer=self.tokenizer,
+             max_length=self.config.max_length,
+         )
+ 
+         eval_dataset = None
+         if validation_items is not None and validation_labels is not None:
+             val_y_numeric = [float(label) for label in validation_labels]
+             val_participant_ids = (
+                 ["_validation_"] * len(validation_items)
+                 if self.config.mixed_effects.mode != "fixed"
+                 else ["_fixed_"] * len(validation_items)
+             )
+             eval_dataset = items_to_dataset(
+                 items=validation_items,
+                 labels=val_y_numeric,
+                 participant_ids=val_participant_ids,
+                 tokenizer=self.tokenizer,
+                 max_length=self.config.max_length,
+             )
+ 
+         # Wrap the encoder and regression head for Trainer
+         wrapped_model = EncoderRegressionWrapper(
+             encoder=self.encoder, regression_head=self.regression_head
+         )
+ 
+         # Create data collator
+         data_collator = MixedEffectsDataCollator(tokenizer=self.tokenizer)
+ 
+         # Create training arguments with checkpointing
+         with tempfile.TemporaryDirectory() as tmpdir:
+             checkpoint_dir = Path(tmpdir) / "checkpoints"
+             checkpoint_dir.mkdir(parents=True, exist_ok=True)
+ 
+             training_args = TrainingArguments(
+                 output_dir=str(checkpoint_dir),
+                 num_train_epochs=self.config.num_epochs,
+                 per_device_train_batch_size=self.config.batch_size,
+                 per_device_eval_batch_size=self.config.batch_size,
+                 learning_rate=self.config.learning_rate,
+                 logging_steps=10,
+                 eval_strategy="epoch" if eval_dataset is not None else "no",
+                 save_strategy="epoch",
+                 save_total_limit=1,
+                 load_best_model_at_end=False,
+                 report_to="none",
+                 remove_unused_columns=False,
+                 use_cpu=self.config.device == "cpu",
+             )
+ 
+             # Import here to avoid circular import
+             from bead.active_learning.trainers.mixed_effects import (  # noqa: PLC0415
+                 MixedEffectsTrainer,
+             )
+ 
+             # Create trainer
+             trainer = MixedEffectsTrainer(
+                 model=wrapped_model,
+                 args=training_args,
+                 train_dataset=train_dataset,
+                 eval_dataset=eval_dataset,
+                 data_collator=data_collator,
+                 tokenizer=self.tokenizer,
+                 random_effects_manager=self.random_effects,
+                 compute_metrics=compute_regression_metrics,
+             )
+ 
+             # Train
+             train_result = trainer.train()
+ 
+             # Get training metrics
+             train_metrics = trainer.evaluate(eval_dataset=train_dataset)
+             metrics: dict[str, float] = {
+                 "train_loss": float(train_result.training_loss),
+                 "train_mse": train_metrics.get("eval_mse", 0.0),
+                 "train_mae": train_metrics.get("eval_mae", 0.0),
+                 "train_r2": train_metrics.get("eval_r2", 0.0),
+             }
+ 
+             # Get validation metrics if eval_dataset was provided
+             if eval_dataset is not None:
+                 val_metrics = trainer.evaluate(eval_dataset=eval_dataset)
+                 metrics.update(
+                     {
+                         "val_mse": val_metrics.get("eval_mse", 0.0),
+                         "val_mae": val_metrics.get("eval_mae", 0.0),
+                         "val_r2": val_metrics.get("eval_r2", 0.0),
+                     }
+                 )
+ 
+         return metrics
+ 
+     def _train_with_custom_loop(
+         self,
+         items: list[Item],
+         y_numeric: list[float],
+         participant_ids: list[str],
+         validation_items: list[Item] | None,
+         validation_labels: list[str] | None,
+     ) -> dict[str, float]:
+         """Train using custom loop for random_slopes mode.
+ 
+         Parameters
+         ----------
+         items : list[Item]
+             Training items.
+         y_numeric : list[float]
+             Numeric labels (continuous values).
+         participant_ids : list[str]
+             Participant IDs.
+         validation_items : list[Item] | None
+             Validation items.
+         validation_labels : list[str] | None
+             Validation labels.
+ 
+         Returns
+         -------
+         dict[str, float]
+             Training metrics.
+         """
+         y = torch.tensor(y_numeric, dtype=torch.float, device=self.config.device)
+ 
+         # Build optimizer parameters
+         params_to_optimize = list(self.encoder.parameters()) + list(
+             self.regression_head.parameters()
+         )
+ 
+         # Add random effects parameters for random_slopes
+         for head in self.random_effects.slopes.values():
+             params_to_optimize.extend(head.parameters())
+ 
+         optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
+ 
+         self.encoder.train()
+         self.regression_head.train()
+ 
+         for _epoch in range(self.config.num_epochs):
+             n_batches = (
+                 len(items) + self.config.batch_size - 1
+             ) // self.config.batch_size
+             epoch_loss = 0.0
+             epoch_mse = 0.0
+ 
+             for i in range(n_batches):
+                 start_idx = i * self.config.batch_size
+                 end_idx = min(start_idx + self.config.batch_size, len(items))
+ 
+                 batch_items = items[start_idx:end_idx]
+                 batch_labels = y[start_idx:end_idx]
+                 batch_participant_ids = participant_ids[start_idx:end_idx]
+ 
+                 embeddings = self._prepare_inputs(batch_items)
+ 
+                 # Per-participant head for random_slopes
+                 mu_list = []
+                 for j, pid in enumerate(batch_participant_ids):
+                     participant_head = self.random_effects.get_slopes(
+                         pid,
+                         fixed_head=self.regression_head,
+                         create_if_missing=True,
+                     )
+                     mu_j = participant_head(embeddings[j : j + 1]).squeeze()
+                     mu_list.append(mu_j)
+                 mu = torch.stack(mu_list)
+ 
+                 # Negative log-likelihood of truncated normal
+                 log_probs = self._truncated_normal_log_prob(
+                     batch_labels, mu, self.config.sigma
+                 )
+                 loss_nll = -log_probs.mean()
+ 
+                 # Add prior regularization
+                 loss_prior = self.random_effects.compute_prior_loss()
+                 loss = loss_nll + loss_prior
+ 
+                 optimizer.zero_grad()
+                 loss.backward()
+                 optimizer.step()
+ 
+                 epoch_loss += loss.item()
+                 # Also track MSE for interpretability
+                 mse = ((mu - batch_labels) ** 2).mean().item()
+                 epoch_mse += mse
+ 
+             epoch_loss = epoch_loss / n_batches
+             epoch_mse = epoch_mse / n_batches
+ 
+         # Report per-batch averages from the final epoch
+         metrics: dict[str, float] = {
+             "train_loss": epoch_loss,
+             "train_mse": epoch_mse,
+         }
+ 
+         return metrics
+ 
+     def _do_predict(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> list[ModelPrediction]:
+         """Perform ordinal scale model prediction.
+ 
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str]
+             Normalized participant IDs.
+ 
+         Returns
+         -------
+         list[ModelPrediction]
+             Predictions with predicted_class as string representation of value.
+         """
+         self.encoder.eval()
+         self.regression_head.eval()
+ 
+         with torch.no_grad():
+             embeddings = self._prepare_inputs(items)
+ 
+             # Forward pass depends on mixed effects mode
+             if self.config.mixed_effects.mode == "fixed":
+                 mu = self.regression_head(embeddings).squeeze(1)
+ 
+             elif self.config.mixed_effects.mode == "random_intercepts":
+                 mu = self.regression_head(embeddings).squeeze(1)
+                 for i, pid in enumerate(participant_ids):
+                     # Unknown participants: use prior mean (zero bias)
+                     bias = self.random_effects.get_intercepts(
+                         pid, n_classes=1, param_name="mu", create_if_missing=False
+                     )
+                     mu[i] = mu[i] + bias.item()
+ 
+             elif self.config.mixed_effects.mode == "random_slopes":
+                 mu_list = []
+                 for i, pid in enumerate(participant_ids):
+                     # Unknown participants: use fixed head
+                     participant_head = self.random_effects.get_slopes(
+                         pid, fixed_head=self.regression_head, create_if_missing=False
+                     )
+                     mu_i = participant_head(embeddings[i : i + 1]).squeeze()
+                     mu_list.append(mu_i)
+                 mu = torch.stack(mu_list)
+ 
+             # Clamp predictions to bounds
+             mu = torch.clamp(mu, self.config.scale.min, self.config.scale.max)
+             predictions_array = mu.cpu().numpy()
+ 
+         predictions = []
+         for i, item in enumerate(items):
+             pred_value = float(predictions_array[i])
+             predictions.append(
+                 ModelPrediction(
+                     item_id=str(item.id),
+                     probabilities={},  # Not applicable for regression
+                     predicted_class=str(pred_value),  # Continuous value as string
+                     confidence=1.0,  # Not applicable for regression
+                 )
+             )
+ 
+         return predictions
+ 
+     def _do_predict_proba(
+         self, items: list[Item], participant_ids: list[str]
+     ) -> np.ndarray:
+         """Perform ordinal scale model probability prediction.
+ 
+         For ordinal scale regression, returns μ values directly.
+ 
+         Parameters
+         ----------
+         items : list[Item]
+             Items to predict.
+         participant_ids : list[str]
+             Normalized participant IDs.
+ 
+         Returns
+         -------
+         np.ndarray
+             Array of shape (n_items, 1) with predicted μ values.
+         """
+         predictions = self._do_predict(items, participant_ids)
+         return np.array([[float(p.predicted_class)] for p in predictions])
+ 
+     def _save_model_components(self, save_path: Path) -> None:
+         """Save model-specific components.
+ 
+         Parameters
+         ----------
+         save_path : Path
+             Directory to save to.
+         """
+         self.encoder.save_pretrained(save_path / "encoder")
+         self.tokenizer.save_pretrained(save_path / "encoder")
+ 
+         torch.save(
+             self.regression_head.state_dict(),
+             save_path / "regression_head.pt",
+         )
+ 
+     def _get_save_state(self) -> dict[str, object]:
+         """Get model-specific state to save.
+ 
+         Returns
+         -------
+         dict[str, object]
+             State dictionary.
+         """
+         return {}
+ 
+     def _restore_training_state(self, config_dict: dict[str, object]) -> None:
+         """Restore model-specific training state.
+ 
+         Parameters
+         ----------
+         config_dict : dict[str, object]
+             Configuration dictionary with training state.
+         """
+         # OrdinalScaleModel doesn't have additional training state to restore
+         pass
+ 
+     def _load_model_components(self, load_path: Path) -> None:
+         """Load model-specific components.
+ 
+         Parameters
+         ----------
+         load_path : Path
+             Directory to load from.
+         """
+         # Load config.json to reconstruct config
+         with open(load_path / "config.json") as f:
+             config_dict = json.load(f)
+ 
+         # Reconstruct MixedEffectsConfig if needed
+         if "mixed_effects" in config_dict and isinstance(
+             config_dict["mixed_effects"], dict
+         ):
+             config_dict["mixed_effects"] = MixedEffectsConfig(
+                 **config_dict["mixed_effects"]
+             )
+ 
+         self.config = OrdinalScaleModelConfig(**config_dict)
+ 
+         self.encoder = AutoModel.from_pretrained(load_path / "encoder")
+         self.tokenizer = AutoTokenizer.from_pretrained(load_path / "encoder")
+ 
+         self._initialize_regression_head()
+         self.regression_head.load_state_dict(
+             torch.load(
+                 load_path / "regression_head.pt", map_location=self.config.device
+             )
+         )
+ 
+         self.encoder.to(self.config.device)
+         self.regression_head.to(self.config.device)
+ 
+     def _get_n_classes_for_random_effects(self) -> int:
+         """Get the number of classes for initializing RandomEffectsManager.
+ 
+         For ordinal scale models, this is 1 (scalar bias).
+ 
+         Returns
+         -------
+         int
+             Always 1 for regression.
+         """
+         return 1
+ 
+     def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
+         """Get the fixed head for random effects.
+ 
+         Returns
+         -------
+         torch.nn.Module | None
+             The regression head, or None if not applicable.
+         """
+         return self.regression_head
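
For reference, the truncated-normal likelihood above can be sanity-checked in isolation. A minimal, self-contained sketch (assumes only torch; the standalone helper below mirrors the math of _truncated_normal_log_prob on the unit interval and is not part of the package):

import torch
from torch.distributions import Normal

def truncated_normal_log_prob(y, mu, sigma, low=0.0, high=1.0):
    # log N(y; mu, sigma) minus the log truncation mass on [low, high],
    # with the bounds standardized and evaluated under Normal(0, 1)
    base = Normal(mu, sigma)
    std = Normal(torch.zeros_like(mu), torch.ones_like(mu))
    alpha = (low - mu) / sigma
    beta = (high - mu) / sigma
    log_z = torch.log(torch.clamp(std.cdf(beta) - std.cdf(alpha), min=1e-8))
    return base.log_prob(y) - log_z

# The truncated density should integrate to ~1.0 over [low, high]
ys = torch.linspace(0.0, 1.0, 10001)
mu = torch.full_like(ys, 0.7)
density = truncated_normal_log_prob(ys, mu, 0.2).exp()
mass = torch.trapezoid(density, ys)
print(float(mass))  # ~1.0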