bead-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/models/random_effects.py (new file, 639 lines)
"""Manager for random effects in GLMM-based active learning.

Implements:
- Random effect storage and retrieval (intercepts and slopes)
- Variance component estimation (G matrix via MLE/REML)
- Empirical Bayes shrinkage for small groups
- Adaptive regularization based on sample counts
- Save/load with variance component history
"""

from __future__ import annotations

import copy
import json
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn as nn

from bead.active_learning.config import MixedEffectsConfig, VarianceComponents

__all__ = ["RandomEffectsManager"]


class RandomEffectsManager:
    """Manages random effects following GLMM theory: u ~ N(0, G).

    Core responsibilities:
    1. Store random effect values: u_i for each participant i
    2. Estimate variance components: σ²_u (the G matrix)
    3. Implement shrinkage: u_shrunk_i = λ_i * u_i + (1-λ_i) * μ_0
    4. Compute prior loss: L_prior = λ * Σ_i w_i * ||u_i - μ_0||²
    5. Handle unknown participants: use population mean (μ_0)

    Attributes
    ----------
    config : MixedEffectsConfig
        Configuration including mode, priors, regularization.
    intercepts : dict[str, dict[str, torch.Tensor]]
        Random intercepts per distribution parameter and participant.
        Outer key: param_name; inner key: participant_id;
        value: bias vector of shape (n_classes,).
    slopes : dict[str, nn.Module]
        Random slopes per participant.
        Key: participant_id; value: model head (nn.Module).
    participant_sample_counts : dict[str, int]
        Training samples per participant (for adaptive regularization).
    variance_components : dict[str, VarianceComponents] | None
        Latest variance component estimates, keyed by parameter name.
    variance_history : list[VarianceComponents]
        Variance components over training (for diagnostics).

    Examples
    --------
    >>> config = MixedEffectsConfig(mode='random_intercepts')
    >>> manager = RandomEffectsManager(config, n_classes=3)

    >>> # Register participants during training
    >>> manager.register_participant("alice", n_samples=10)
    >>> manager.register_participant("bob", n_samples=15)

    >>> # Get intercepts (creates if missing)
    >>> bias_alice = manager.get_intercepts("alice", n_classes=3, param_name="mu")

    >>> # Estimate variance components after training
    >>> var_comps = manager.estimate_variance_components()
    >>> print(f"σ²_u = {var_comps['mu'].variance:.3f}")

    >>> # Compute prior loss for regularization
    >>> loss_prior = manager.compute_prior_loss()
    """

    def __init__(self, config: MixedEffectsConfig, **kwargs: Any) -> None:
        """Initialize random effects manager.

        Parameters
        ----------
        config : MixedEffectsConfig
            GLMM configuration.
        **kwargs : Any
            Additional arguments (e.g., n_classes, hidden_dim).
            Required arguments depend on mode.

        Raises
        ------
        ValueError
            If mode='random_slopes' but required kwargs missing.
        """
        self.config = config
        # Nested dict structure: intercepts[param_name][participant_id] = tensor
        # Examples:
        #   intercepts["mu"]["alice"] = tensor([0.12])
        #   intercepts["cutpoint_1"]["alice"] = tensor([0.05])
        self.intercepts: dict[str, dict[str, torch.Tensor]] = {}
        self.slopes: dict[str, nn.Module] = {}
        self.participant_sample_counts: dict[str, int] = {}

        self.variance_components: dict[str, VarianceComponents] | None = None
        self.variance_history: list[VarianceComponents] = []

        # Store kwargs for creating new random effects
        self.creation_kwargs = kwargs

    def register_participant(self, participant_id: str, n_samples: int) -> None:
        """Register participant and track sample count.

        Used for:
        - Adaptive regularization (fewer samples → stronger regularization)
        - Shrinkage estimation (fewer samples → shrink toward mean)
        - Variance component estimation

        Parameters
        ----------
        participant_id : str
            Participant identifier.
        n_samples : int
            Number of samples for this participant.

        Raises
        ------
        ValueError
            If participant_id is empty or n_samples is not positive.

        Examples
        --------
        >>> manager.register_participant("alice", n_samples=10)
        >>> manager.register_participant("bob", n_samples=15)
        """
        if not participant_id:
            raise ValueError(
                "participant_id cannot be empty. "
                "Ensure all participants have valid string identifiers."
            )
        if n_samples <= 0:
            raise ValueError(
                f"n_samples must be positive, got {n_samples}. "
                f"Each participant must have at least 1 sample."
            )

        # Accumulate samples if participant seen before
        if participant_id in self.participant_sample_counts:
            self.participant_sample_counts[participant_id] += n_samples
        else:
            self.participant_sample_counts[participant_id] = n_samples

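    # Usage sketch (illustrative): registration is cumulative, so calling it
    # once per training batch is safe:
    # >>> manager.register_participant("alice", n_samples=10)
    # >>> manager.register_participant("alice", n_samples=5)
    # >>> manager.participant_sample_counts["alice"]
    # 15
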
    def get_intercepts(
        self,
        participant_id: str,
        n_classes: int,
        param_name: str,
        create_if_missing: bool = True,
    ) -> torch.Tensor:
        """Get random intercepts for specific distribution parameter.

        Parameters
        ----------
        participant_id : str
            Participant identifier.
        n_classes : int
            Number of classes (length of bias vector).
        param_name : str
            Name of the distribution parameter (e.g., "mu", "cutpoint_1", "cutpoint_2").
        create_if_missing : bool, default=True
            Whether to create new intercepts for unknown participants.
            True: Training (create new random effects)
            False: Prediction (use prior mean for unknown)

        Returns
        -------
        torch.Tensor
            Bias vector of shape (n_classes,).

        Raises
        ------
        ValueError
            If mode is not 'random_intercepts'.

        Examples
        --------
        >>> bias = manager.get_intercepts("alice", n_classes=3, param_name="mu")
        >>> bias.shape
        torch.Size([3])

        >>> # Multi-parameter: Ordered beta
        >>> mu_bias = manager.get_intercepts("alice", 1, param_name="mu")
        >>> c1_bias = manager.get_intercepts("alice", 1, param_name="cutpoint_1")
        """
        if self.config.mode != "random_intercepts":
            raise ValueError(
                f"get_intercepts() called but mode is '{self.config.mode}', "
                f"expected 'random_intercepts'. "
                f"Use mode='random_intercepts' in MixedEffectsConfig."
            )

        # Initialize parameter dict if first time seeing this parameter
        if param_name not in self.intercepts:
            self.intercepts[param_name] = {}

        param_dict = self.intercepts[param_name]

        # Known participant: return learned intercepts
        if participant_id in param_dict:
            return param_dict[participant_id]

        # Unknown participant: use prior mean
        if not create_if_missing:
            return torch.zeros(n_classes) + self.config.prior_mean

        # Create new intercepts from prior: u_i ~ N(μ_0, σ²_0)
        bias = (
            torch.randn(n_classes) * np.sqrt(self.config.prior_variance)
            + self.config.prior_mean
        )
        bias.requires_grad = True
        param_dict[participant_id] = bias
        return bias

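    # Usage sketch (illustrative): create_if_missing separates the training
    # path (sample a fresh effect from the prior) from the prediction path
    # (fall back to the prior mean for unseen participants):
    # >>> u_train = manager.get_intercepts("carol", 3, "mu")
    # >>> u_pred = manager.get_intercepts(
    # ...     "dave", 3, "mu", create_if_missing=False
    # ... )
    # >>> torch.allclose(u_pred, torch.full((3,), manager.config.prior_mean))
    # True
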
    def get_intercepts_with_shrinkage(
        self, participant_id: str, n_classes: int, param_name: str = "bias"
    ) -> torch.Tensor:
        """Get random intercepts with Empirical Bayes shrinkage.

        Implements shrinkage toward population mean:

            u_shrunk_i = λ_i * u_mle_i + (1 - λ_i) * μ_0

        where:
            λ_i = n_i / (n_i + k)
            k ≈ σ²_ε / σ²_u (ratio of residual to random effect variance)

        For participants with few samples, shrink toward μ_0 (population mean).
        For participants with many samples, use their specific estimate.

        Parameters
        ----------
        participant_id : str
            Participant identifier.
        n_classes : int
            Number of classes.
        param_name : str, default="bias"
            Name of the distribution parameter.

        Returns
        -------
        torch.Tensor
            Shrunk bias vector of shape (n_classes,).

        Examples
        --------
        >>> # Participant with 2 samples → strong shrinkage
        >>> manager.register_participant("alice", n_samples=2)
        >>> bias_shrunk = manager.get_intercepts_with_shrinkage("alice", 3)

        >>> # Participant with 100 samples → little shrinkage
        >>> manager.register_participant("bob", n_samples=100)
        >>> bias_shrunk_bob = manager.get_intercepts_with_shrinkage("bob", 3)
        """
        if self.config.mode != "random_intercepts":
            raise ValueError(
                f"Shrinkage only for random_intercepts mode, got '{self.config.mode}'"
            )

        # Get MLE estimate (or prior if unknown)
        u_mle = self.get_intercepts(
            participant_id, n_classes, param_name, create_if_missing=False
        )

        # Unknown participant: return prior mean (no shrinkage needed)
        param_dict = self.intercepts.get(param_name, {})
        if participant_id not in param_dict:
            return u_mle

        # Compute shrinkage factor λ_i
        n_i = self.participant_sample_counts.get(participant_id, 1)

        # Estimate k from variance components if available; variance_components
        # is keyed by parameter name (see estimate_variance_components)
        if (
            self.variance_components is not None
            and param_name in self.variance_components
        ):
            sigma2_u = self.variance_components[param_name].variance
            # Estimate σ²_ε from residuals (simplified: assume σ²_ε ≈ 1)
            sigma2_epsilon = 1.0
            k = sigma2_epsilon / max(sigma2_u, 1e-6)
        else:
            # Fallback: use min_samples as proxy for k
            k = self.config.min_samples_for_random_effects

        lambda_i = n_i / (n_i + k)

        # Shrinkage: u_shrunk = λ * u_mle + (1-λ) * μ_0
        mu_0 = self.config.prior_mean
        u_shrunk = lambda_i * u_mle + (1 - lambda_i) * mu_0

        return u_shrunk

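    # Worked example (illustrative): with k = 5, a participant with n_i = 2
    # gets λ = 2 / (2 + 5) ≈ 0.29, so roughly 71% of the returned value comes
    # from the population mean μ_0; with n_i = 100, λ = 100 / 105 ≈ 0.95 and
    # shrinkage is negligible.
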
    def get_slopes(
        self,
        participant_id: str,
        fixed_head: nn.Module,
        create_if_missing: bool = True,
    ) -> nn.Module:
        """Get random slopes (model head) for participant.

        Behavior:
        - Known participant: Return learned head
        - Unknown participant:
            - If create_if_missing=True: Clone fixed_head and add noise
            - If create_if_missing=False: Return clone of fixed_head

        Parameters
        ----------
        participant_id : str
            Participant identifier.
        fixed_head : nn.Module
            Fixed effects head to clone for new participants.
        create_if_missing : bool, default=True
            Whether to create new slopes for unknown participants.

        Returns
        -------
        nn.Module
            Model head for this participant.

        Raises
        ------
        ValueError
            If mode is not 'random_slopes'.

        Examples
        --------
        >>> fixed_head = nn.Linear(768, 3)
        >>> # Training: Create participant-specific head
        >>> head_alice = manager.get_slopes("alice", fixed_head, create_if_missing=True)

        >>> # Prediction: Use fixed head for unknown
        >>> head_unknown = manager.get_slopes(
        ...     "unknown", fixed_head, create_if_missing=False
        ... )
        """
        if self.config.mode != "random_slopes":
            raise ValueError(
                f"get_slopes() called but mode is '{self.config.mode}', "
                f"expected 'random_slopes'"
            )

        # Known participant: return learned slopes
        if participant_id in self.slopes:
            return self.slopes[participant_id]

        # Unknown participant: return clone of fixed head
        if not create_if_missing:
            return copy.deepcopy(fixed_head)

        # Create new slopes: φ_i = θ + noise
        # Clone fixed head and add Gaussian noise to parameters
        participant_head = copy.deepcopy(fixed_head)

        with torch.no_grad():
            for param in participant_head.parameters():
                noise = torch.randn_like(param) * np.sqrt(self.config.prior_variance)
                param.add_(noise)

        self.slopes[participant_id] = participant_head
        return participant_head

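    # Usage sketch (illustrative): with create_if_missing=False an unseen
    # participant receives an exact copy of the fixed head, so prediction
    # falls back to the population-level model:
    # >>> head = manager.get_slopes("unseen", fixed_head, create_if_missing=False)
    # >>> all(
    # ...     torch.equal(p, q)
    # ...     for p, q in zip(head.parameters(), fixed_head.parameters())
    # ... )
    # True
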
    def estimate_variance_components(
        self,
    ) -> dict[str, VarianceComponents] | None:
        """Estimate variance components (G matrix) from random effects.

        Returns
        -------
        dict[str, VarianceComponents] | None
            Dictionary mapping param_name -> VarianceComponents.
            For single-parameter models (most common), returns dict with one key.
            For multi-parameter models (e.g., ordered beta), returns dict
            with multiple keys.
            Returns None if mode='fixed' or no random effects exist yet.

        Examples
        --------
        >>> # Single parameter (most common)
        >>> var_comps = manager.estimate_variance_components()
        >>> print(f"Mu variance: {var_comps['mu'].variance:.3f}")

        >>> # Multi-parameter (ordered beta)
        >>> var_comps = manager.estimate_variance_components()
        >>> print(f"Mu variance: {var_comps['mu'].variance:.3f}")
        >>> print(f"Cutpoint_1 variance: {var_comps['cutpoint_1'].variance:.3f}")
        """
        if self.config.mode == "fixed":
            return None

        if self.config.mode == "random_intercepts":
            if not self.intercepts:
                return None

            variance_components: dict[str, VarianceComponents] = {}
            for param_name, param_intercepts in self.intercepts.items():
                if not param_intercepts:
                    continue

                all_intercepts = torch.stack(list(param_intercepts.values()))
                if len(param_intercepts) == 1:
                    variance = 0.0
                else:
                    variance = torch.var(all_intercepts, unbiased=True).item()

                variance_components[param_name] = VarianceComponents(
                    grouping_factor="participant",
                    effect_type="intercept",
                    variance=variance,
                    n_groups=len(param_intercepts),
                    n_observations_per_group=self.participant_sample_counts.copy(),
                )

            # Update variance_components and history
            self.variance_components = variance_components
            # Store the first param's variance in history for backwards compatibility
            first_param = next(iter(variance_components.values()))
            self.variance_history.append(first_param)

            return variance_components

        elif self.config.mode == "random_slopes":
            if not self.slopes:
                return None

            all_params: list[torch.Tensor] = []
            for head in self.slopes.values():
                params_flat = torch.cat([p.flatten() for p in head.parameters()])
                all_params.append(params_flat)

            all_params_tensor = torch.stack(all_params)
            if len(self.slopes) == 1:
                variance = 0.0
            else:
                variance = torch.var(all_params_tensor, unbiased=True).item()

            # Random slopes still returns a single variance component (not per-parameter)
            slope_var_comp = VarianceComponents(
                grouping_factor="participant",
                effect_type="slope",
                variance=variance,
                n_groups=len(self.slopes),
                n_observations_per_group=self.participant_sample_counts.copy(),
            )
            result = {"slopes": slope_var_comp}

            # Update variance_components and history
            self.variance_components = result
            self.variance_history.append(slope_var_comp)

            return result

        return None

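    # Usage sketch (illustrative): each call re-estimates G from the current
    # per-participant effects and appends one entry to variance_history, so
    # the trajectory of σ²_u across active-learning rounds can be inspected:
    # >>> var_comps = manager.estimate_variance_components()
    # >>> sorted(var_comps)  # e.g., ['mu'] for a single-parameter model
    # >>> [vc.variance for vc in manager.variance_history]  # one per call
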
    def compute_prior_loss(self) -> torch.Tensor:
        """Compute regularization loss toward prior.

        Implements adaptive regularization:

            L_prior = λ * Σ_i w_i * ||u_i - μ_0||²

        where:
            w_i = 1 / max(n_i, min_samples)  (adaptive weighting)
            λ = regularization_strength

        Participants with fewer samples get stronger regularization.
        This prevents overfitting when a participant has little data.

        For multi-parameter random effects, sums over all parameters.

        Returns
        -------
        torch.Tensor
            Scalar regularization loss to add to training loss.

        Examples
        --------
        >>> # During training:
        >>> loss_data = cross_entropy(logits, labels)
        >>> loss_prior = manager.compute_prior_loss()
        >>> loss_total = loss_data + loss_prior
        >>> loss_total.backward()
        """
        if self.config.mode == "fixed":
            return torch.tensor(0.0)

        loss = torch.tensor(0.0)

        if self.config.mode == "random_intercepts":
            # Iterate over all parameters (e.g., "mu", "cutpoint_1", "cutpoint_2")
            for _param_name, param_dict in self.intercepts.items():
                for participant_id, bias in param_dict.items():
                    # Deviation from prior mean
                    deviation = bias - self.config.prior_mean
                    squared_dev = torch.sum(deviation**2)

                    # Adaptive weight
                    if self.config.adaptive_regularization:
                        n_samples = self.participant_sample_counts.get(
                            participant_id, 1
                        )
                        weight = 1.0 / max(
                            n_samples, self.config.min_samples_for_random_effects
                        )
                    else:
                        weight = 1.0

                    loss += weight * squared_dev

        elif self.config.mode == "random_slopes":
            for participant_id, head in self.slopes.items():
                # Sum squared parameters (deviation from 0)
                squared_dev = sum(torch.sum(param**2) for param in head.parameters())

                # Adaptive weight
                if self.config.adaptive_regularization:
                    n_samples = self.participant_sample_counts.get(participant_id, 1)
                    weight = 1.0 / max(
                        n_samples, self.config.min_samples_for_random_effects
                    )
                else:
                    weight = 1.0

                loss += weight * squared_dev

        return self.config.regularization_strength * loss

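    # Training sketch (illustrative; `model`, `loss_fn`, `logits`, and
    # `labels` are placeholders): the intercepts created by get_intercepts()
    # are leaf tensors with requires_grad=True, so they can be optimized
    # jointly with the model while compute_prior_loss() supplies the GLMM
    # regularizer:
    # >>> u_params = [u for d in manager.intercepts.values() for u in d.values()]
    # >>> optimizer = torch.optim.Adam([*model.parameters(), *u_params], lr=1e-3)
    # >>> loss = loss_fn(logits, labels) + manager.compute_prior_loss()
    # >>> loss.backward()
    # >>> optimizer.step()
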
    def save(self, path: Path) -> None:
        """Save random effects to disk.

        Parameters
        ----------
        path : Path
            Directory to save random effects.
        """
        path.mkdir(parents=True, exist_ok=True)

        # Save intercepts (nested dict)
        if self.config.mode == "random_intercepts" and self.intercepts:
            # Convert to CPU and detach
            intercepts_cpu: dict[str, dict[str, torch.Tensor]] = {}
            for param_name, param_dict in self.intercepts.items():
                intercepts_cpu[param_name] = {
                    pid: tensor.detach().cpu() for pid, tensor in param_dict.items()
                }
            torch.save(intercepts_cpu, path / "intercepts.pt")

        # Save slopes
        if self.config.mode == "random_slopes" and self.slopes:
            slopes_state = {pid: head.state_dict() for pid, head in self.slopes.items()}
            torch.save(slopes_state, path / "slopes.pt")

        # Save sample counts
        with open(path / "sample_counts.json", "w") as f:
            json.dump(self.participant_sample_counts, f)

        # Save variance history (if any)
        if self.variance_history:
            # Serialize VarianceComponents to JSON
            variance_history_data = [
                vc.model_dump() if hasattr(vc, "model_dump") else vc
                for vc in self.variance_history
            ]
            with open(path / "variance_history.json", "w") as f:
                json.dump(variance_history_data, f, indent=2)

    def load(self, path: Path, fixed_head: nn.Module | None = None) -> None:
        """Load random effects from disk.

        Parameters
        ----------
        path : Path
            Directory to load from.
        fixed_head : nn.Module | None
            Fixed head (required if mode='random_slopes').

        Raises
        ------
        FileNotFoundError
            If path doesn't exist.
        ValueError
            If mode='random_slopes' but fixed_head is None.

        Examples
        --------
        >>> manager.load(Path("model_checkpoint/random_effects"))
        """
        if not path.exists():
            raise FileNotFoundError(f"Random effects directory not found: {path}")

        # Load intercepts (nested dict)
        if self.config.mode == "random_intercepts":
            intercepts_path = path / "intercepts.pt"
            if intercepts_path.exists():
                self.intercepts = torch.load(intercepts_path, weights_only=False)

        # Load slopes
        if self.config.mode == "random_slopes":
            if fixed_head is None:
                raise ValueError(
                    "fixed_head is required when loading random slopes. "
                    "Pass the fixed effects head to load()."
                )

            slopes_path = path / "slopes.pt"
            if slopes_path.exists():
                slopes_state = torch.load(slopes_path, weights_only=False)
                self.slopes = {}
                for pid, state_dict in slopes_state.items():
                    head = copy.deepcopy(fixed_head)
                    head.load_state_dict(state_dict)
                    self.slopes[pid] = head

        # Load sample counts
        sample_counts_path = path / "sample_counts.json"
        if sample_counts_path.exists():
            with open(sample_counts_path) as f:
                self.participant_sample_counts = json.load(f)

        # Load variance history (if any)
        variance_history_path = path / "variance_history.json"
        if variance_history_path.exists():
            with open(variance_history_path) as f:
                variance_history_data = json.load(f)
            # Deserialize VarianceComponents (imported at module level)
            self.variance_history = [
                VarianceComponents(**vc_data) if isinstance(vc_data, dict) else vc_data
                for vc_data in variance_history_data
            ]
            # Restore variance_components from history
            if self.variance_history:
                last_vc = self.variance_history[-1]
                # Infer param name from effect type for backwards compatibility
                param_key = "slopes" if last_vc.effect_type == "slope" else "bias"
                self.variance_components = {param_key: last_vc}
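
A minimal save/load round trip using only the methods above (the checkpoint
path is illustrative, and MixedEffectsConfig fields beyond mode are assumed to
take their defaults):

    from pathlib import Path

    from bead.active_learning.config import MixedEffectsConfig
    from bead.active_learning.models.random_effects import RandomEffectsManager

    config = MixedEffectsConfig(mode="random_intercepts")
    manager = RandomEffectsManager(config, n_classes=3)
    manager.register_participant("alice", n_samples=10)
    _ = manager.get_intercepts("alice", n_classes=3, param_name="mu")
    manager.estimate_variance_components()
    manager.save(Path("checkpoints/random_effects"))

    restored = RandomEffectsManager(config, n_classes=3)
    restored.load(Path("checkpoints/random_effects"))  # pass fixed_head=... for random_slopes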