bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/models/random_effects.py
@@ -0,0 +1,639 @@
"""Manager for random effects in GLMM-based active learning.

Implements:
- Random effect storage and retrieval (intercepts and slopes)
- Variance component estimation (G matrix via MLE/REML)
- Empirical Bayes shrinkage for small groups
- Adaptive regularization based on sample counts
- Save/load with variance component history
"""

from __future__ import annotations

import copy
import json
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn as nn

from bead.active_learning.config import MixedEffectsConfig, VarianceComponents

__all__ = ["RandomEffectsManager"]


class RandomEffectsManager:
    """Manages random effects following GLMM theory: u ~ N(0, G).

    Core responsibilities:
    1. Store random effect values: u_i for each participant i
    2. Estimate variance components: σ²_u (the G matrix)
    3. Implement shrinkage: u_shrunk_i = λ_i * u_i + (1-λ_i) * μ_0
    4. Compute prior loss: L_prior = λ * Σ_i w_i * ||u_i - μ_0||²
    5. Handle unknown participants: Use population mean (μ_0)

    Attributes
    ----------
    config : MixedEffectsConfig
        Configuration including mode, priors, regularization.
    intercepts : dict[str, torch.Tensor]
        Random intercepts per participant.
        Key: participant_id, Value: bias vector of shape (n_classes,)
    slopes : dict[str, nn.Module]
        Random slopes per participant.
        Key: participant_id, Value: model head (nn.Module)
    participant_sample_counts : dict[str, int]
        Training samples per participant (for adaptive regularization).
    variance_components : VarianceComponents | None
        Latest variance component estimates.
    variance_history : list[VarianceComponents]
        Variance components over training (for diagnostics).

    Examples
    --------
    >>> config = MixedEffectsConfig(mode='random_intercepts')
    >>> manager = RandomEffectsManager(config, n_classes=3)

    >>> # Register participants during training
    >>> manager.register_participant("alice", n_samples=10)
    >>> manager.register_participant("bob", n_samples=15)

    >>> # Get intercepts (creates if missing)
    >>> bias_alice = manager.get_intercepts("alice", n_classes=3)

    >>> # Estimate variance components after training
    >>> var_comp = manager.estimate_variance_components()
    >>> print(f"σ²_u = {var_comp.variance:.3f}")

    >>> # Compute prior loss for regularization
    >>> loss_prior = manager.compute_prior_loss()
    """

    def __init__(self, config: MixedEffectsConfig, **kwargs: Any) -> None:
        """Initialize random effects manager.

        Parameters
        ----------
        config : MixedEffectsConfig
            GLMM configuration.
        **kwargs : Any
            Additional arguments (e.g., n_classes, hidden_dim).
            Required arguments depend on mode.

        Raises
        ------
        ValueError
            If mode='random_slopes' but required kwargs missing.
        """
        self.config = config
        # Nested dict structure: intercepts[param_name][participant_id] = tensor
        # Examples:
        #   intercepts["mu"]["alice"] = tensor([0.12])
        #   intercepts["cutpoint_1"]["alice"] = tensor([0.05])
        self.intercepts: dict[str, dict[str, torch.Tensor]] = {}
        self.slopes: dict[str, nn.Module] = {}
        self.participant_sample_counts: dict[str, int] = {}

        self.variance_components: VarianceComponents | None = None
        self.variance_history: list[VarianceComponents] = []

        # Store kwargs for creating new random effects
        self.creation_kwargs = kwargs

    def register_participant(self, participant_id: str, n_samples: int) -> None:
        """Register participant and track sample count.

        Used for:
        - Adaptive regularization (fewer samples → stronger regularization)
        - Shrinkage estimation (fewer samples → shrink toward mean)
        - Variance component estimation

        Parameters
        ----------
        participant_id : str
            Participant identifier.
        n_samples : int
            Number of samples for this participant.

        Raises
        ------
        ValueError
            If participant_id empty or n_samples not positive.

        Examples
        --------
        >>> manager.register_participant("alice", n_samples=10)
        >>> manager.register_participant("bob", n_samples=15)
        """
        if not participant_id:
            raise ValueError(
                "participant_id cannot be empty. "
                "Ensure all participants have valid string identifiers."
            )
        if n_samples <= 0:
            raise ValueError(
                f"n_samples must be positive, got {n_samples}. "
                f"Each participant must have at least 1 sample."
            )

        # Accumulate samples if participant seen before
        if participant_id in self.participant_sample_counts:
            self.participant_sample_counts[participant_id] += n_samples
        else:
            self.participant_sample_counts[participant_id] = n_samples

    def get_intercepts(
        self,
        participant_id: str,
        n_classes: int,
        param_name: str,
        create_if_missing: bool = True,
    ) -> torch.Tensor:
        """Get random intercepts for specific distribution parameter.

        Parameters
        ----------
        participant_id : str
            Participant identifier.
        n_classes : int
            Number of classes (length of bias vector).
        param_name : str
            Name of the distribution parameter (e.g., "mu", "cutpoint_1", "cutpoint_2").
        create_if_missing : bool, default=True
            Whether to create new intercepts for unknown participants.
            True: Training (create new random effects)
            False: Prediction (use prior mean for unknown)

        Returns
        -------
        torch.Tensor
            Bias vector of shape (n_classes,).

        Raises
        ------
        ValueError
            If mode is not 'random_intercepts'.

        Examples
        --------
        >>> bias = manager.get_intercepts("alice", n_classes=3, param_name="mu")
        >>> bias.shape
        torch.Size([3])

        >>> # Multi-parameter: Ordered beta
        >>> mu_bias = manager.get_intercepts("alice", 1, param_name="mu")
        >>> c1_bias = manager.get_intercepts("alice", 1, param_name="cutpoint_1")
        """
        if self.config.mode != "random_intercepts":
            raise ValueError(
                f"get_intercepts() called but mode is '{self.config.mode}', "
                f"expected 'random_intercepts'. "
                f"Use mode='random_intercepts' in MixedEffectsConfig."
            )

        # Initialize parameter dict if first time seeing this parameter
        if param_name not in self.intercepts:
            self.intercepts[param_name] = {}

        param_dict = self.intercepts[param_name]

        # Known participant: return learned intercepts
        if participant_id in param_dict:
            return param_dict[participant_id]

        # Unknown participant: use prior mean
        if not create_if_missing:
            return torch.zeros(n_classes) + self.config.prior_mean

        # Create new intercepts from prior: u_i ~ N(μ_0, σ²_0)
        bias = (
            torch.randn(n_classes) * np.sqrt(self.config.prior_variance)
            + self.config.prior_mean
        )
        bias.requires_grad = True
        param_dict[participant_id] = bias
        return bias

    def get_intercepts_with_shrinkage(
        self, participant_id: str, n_classes: int, param_name: str = "bias"
    ) -> torch.Tensor:
        """Get random intercepts with Empirical Bayes shrinkage.

        Implements shrinkage toward population mean:

            u_shrunk_i = λ_i * u_mle_i + (1 - λ_i) * μ_0

        where:
            λ_i = n_i / (n_i + k)
            k ≈ σ²_ε / σ²_u (ratio of residual to random effect variance)

        For participants with few samples, shrink toward μ_0 (population mean).
        For participants with many samples, use their specific estimate.

        Parameters
        ----------
        participant_id : str
            Participant identifier.
        n_classes : int
            Number of classes.
        param_name : str, default="bias"
            Name of the distribution parameter.

        Returns
        -------
        torch.Tensor
            Shrunk bias vector of shape (n_classes,).

        Examples
        --------
        >>> # Participant with 2 samples → strong shrinkage
        >>> manager.register_participant("alice", n_samples=2)
        >>> bias_shrunk = manager.get_intercepts_with_shrinkage("alice", 3)

        >>> # Participant with 100 samples → little shrinkage
        >>> manager.register_participant("bob", n_samples=100)
        >>> bias_shrunk_bob = manager.get_intercepts_with_shrinkage("bob", 3)
        """
        if self.config.mode != "random_intercepts":
            raise ValueError(
                f"Shrinkage only for random_intercepts mode, got '{self.config.mode}'"
            )

        # Get MLE estimate (or prior if unknown)
        u_mle = self.get_intercepts(
            participant_id, n_classes, param_name, create_if_missing=False
        )

        # Unknown participant: return prior mean (no shrinkage needed)
        param_dict = self.intercepts.get(param_name, {})
        if participant_id not in param_dict:
            return u_mle

        # Compute shrinkage factor λ_i
        n_i = self.participant_sample_counts.get(participant_id, 1)

        # Estimate k from variance components if available
        if self.variance_components is not None:
            sigma2_u = self.variance_components.variance
            # Estimate σ²_ε from residuals (simplified: assume σ²_ε ≈ 1)
            sigma2_epsilon = 1.0
            k = sigma2_epsilon / max(sigma2_u, 1e-6)
        else:
            # Fallback: use min_samples as proxy for k
            k = self.config.min_samples_for_random_effects

        lambda_i = n_i / (n_i + k)

        # Shrinkage: u_shrunk = λ * u_mle + (1-λ) * μ_0
        mu_0 = self.config.prior_mean
        u_shrunk = lambda_i * u_mle + (1 - lambda_i) * mu_0

        return u_shrunk

    def get_slopes(
        self,
        participant_id: str,
        fixed_head: nn.Module,
        create_if_missing: bool = True,
    ) -> nn.Module:
        """Get random slopes (model head) for participant.

        Behavior:
        - Known participant: Return learned head
        - Unknown participant:
            - If create_if_missing=True: Clone fixed_head and add noise
            - If create_if_missing=False: Return clone of fixed_head

        Parameters
        ----------
        participant_id : str
            Participant identifier.
        fixed_head : nn.Module
            Fixed effects head to clone for new participants.
        create_if_missing : bool, default=True
            Whether to create new slopes for unknown participants.

        Returns
        -------
        nn.Module
            Model head for this participant.

        Raises
        ------
        ValueError
            If mode is not 'random_slopes'.

        Examples
        --------
        >>> fixed_head = nn.Linear(768, 3)
        >>> # Training: Create participant-specific head
        >>> head_alice = manager.get_slopes("alice", fixed_head, create_if_missing=True)

        >>> # Prediction: Use fixed head for unknown
        >>> head_unknown = manager.get_slopes(
        ...     "unknown", fixed_head, create_if_missing=False
        ... )
        """
        if self.config.mode != "random_slopes":
            raise ValueError(
                f"get_slopes() called but mode is '{self.config.mode}', "
                f"expected 'random_slopes'"
            )

        # Known participant: return learned slopes
        if participant_id in self.slopes:
            return self.slopes[participant_id]

        # Unknown participant: return clone of fixed head
        if not create_if_missing:
            return copy.deepcopy(fixed_head)

        # Create new slopes: φ_i = θ + noise
        # Clone fixed head and add Gaussian noise to parameters
        participant_head = copy.deepcopy(fixed_head)

        with torch.no_grad():
            for param in participant_head.parameters():
                noise = torch.randn_like(param) * np.sqrt(self.config.prior_variance)
                param.add_(noise)

        self.slopes[participant_id] = participant_head
        return participant_head

    def estimate_variance_components(
        self,
    ) -> dict[str, VarianceComponents] | None:
        """Estimate variance components (G matrix) from random effects.

        Returns
        -------
        dict[str, VarianceComponents] | None
            Dictionary mapping param_name -> VarianceComponents.
            For single-parameter models (most common), returns dict with one key.
            For multi-parameter models (e.g., ordered beta), returns dict
            with multiple keys.
            Returns None if mode='fixed' or no random_slopes.

        Examples
        --------
        >>> # Single parameter (most common)
        >>> var_comps = manager.estimate_variance_components()
        >>> print(f"Mu variance: {var_comps['mu'].variance:.3f}")

        >>> # Multi-parameter (ordered beta)
        >>> var_comps = manager.estimate_variance_components()
        >>> print(f"Mu variance: {var_comps['mu'].variance:.3f}")
        >>> print(f"Cutpoint_1 variance: {var_comps['cutpoint_1'].variance:.3f}")
        """
        if self.config.mode == "fixed":
            return None

        if self.config.mode == "random_intercepts":
            if not self.intercepts:
                return None

            variance_components: dict[str, VarianceComponents] = {}
            for param_name, param_intercepts in self.intercepts.items():
                if not param_intercepts:
                    continue

                all_intercepts = torch.stack(list(param_intercepts.values()))
                if len(param_intercepts) == 1:
                    variance = 0.0
                else:
                    variance = torch.var(all_intercepts, unbiased=True).item()

                variance_components[param_name] = VarianceComponents(
                    grouping_factor="participant",
                    effect_type="intercept",
                    variance=variance,
                    n_groups=len(param_intercepts),
                    n_observations_per_group=self.participant_sample_counts.copy(),
                )

            # Update variance_components and history
            self.variance_components = variance_components
            # Store the first param's variance in history for backwards compatibility
            first_param = next(iter(variance_components.values()))
            self.variance_history.append(first_param)

            return variance_components

        elif self.config.mode == "random_slopes":
            if not self.slopes:
                return None

            all_params: list[torch.Tensor] = []
            for head in self.slopes.values():
                params_flat = torch.cat([p.flatten() for p in head.parameters()])
                all_params.append(params_flat)

            all_params_tensor = torch.stack(all_params)
            if len(self.slopes) == 1:
                variance = 0.0
            else:
                variance = torch.var(all_params_tensor, unbiased=True).item()

            # Random slopes still returns single variance component (not per-parameter)
            slope_var_comp = VarianceComponents(
                grouping_factor="participant",
                effect_type="slope",
                variance=variance,
                n_groups=len(self.slopes),
                n_observations_per_group=self.participant_sample_counts.copy(),
            )
            result = {"slopes": slope_var_comp}

            # Update variance_components and history
            self.variance_components = result
            self.variance_history.append(slope_var_comp)

            return result

        return None

    def compute_prior_loss(self) -> torch.Tensor:
        """Compute regularization loss toward prior.

        Implements adaptive regularization:

            L_prior = λ * Σ_i w_i * ||u_i - μ_0||²

        where:
            w_i = 1 / max(n_i, min_samples)  (adaptive weighting)
            λ = regularization_strength

        Participants with fewer samples get stronger regularization.
        This prevents overfitting when participant has little data.

        For multi-parameter random effects, sums over all parameters.

        Returns
        -------
        torch.Tensor
            Scalar regularization loss to add to training loss.

        Examples
        --------
        >>> # During training:
        >>> loss_data = cross_entropy(logits, labels)
        >>> loss_prior = manager.compute_prior_loss()
        >>> loss_total = loss_data + loss_prior
        >>> loss_total.backward()
        """
        if self.config.mode == "fixed":
            return torch.tensor(0.0)

        loss = torch.tensor(0.0)

        if self.config.mode == "random_intercepts":
            # Iterate over all parameters (e.g., "mu", "cutpoint_1", "cutpoint_2")
            for _param_name, param_dict in self.intercepts.items():
                for participant_id, bias in param_dict.items():
                    # Deviation from prior mean
                    deviation = bias - self.config.prior_mean
                    squared_dev = torch.sum(deviation**2)

                    # Adaptive weight
                    if self.config.adaptive_regularization:
                        n_samples = self.participant_sample_counts.get(
                            participant_id, 1
                        )
                        weight = 1.0 / max(
                            n_samples, self.config.min_samples_for_random_effects
                        )
                    else:
                        weight = 1.0

                    loss += weight * squared_dev

        elif self.config.mode == "random_slopes":
            for participant_id, head in self.slopes.items():
                # Sum squared parameters (deviation from 0)
                squared_dev = sum(torch.sum(param**2) for param in head.parameters())

                # Adaptive weight
                if self.config.adaptive_regularization:
                    n_samples = self.participant_sample_counts.get(participant_id, 1)
                    weight = 1.0 / max(
                        n_samples, self.config.min_samples_for_random_effects
                    )
                else:
                    weight = 1.0

                loss += weight * squared_dev

        return self.config.regularization_strength * loss

    def save(self, path: Path) -> None:
        """Save random effects to disk.

        Parameters
        ----------
        path : Path
            Directory to save random effects.
        """
        path.mkdir(parents=True, exist_ok=True)

        # Save intercepts (nested dict)
        if self.config.mode == "random_intercepts" and self.intercepts:
            # Convert to CPU and detach
            intercepts_cpu: dict[str, dict[str, torch.Tensor]] = {}
            for param_name, param_dict in self.intercepts.items():
                intercepts_cpu[param_name] = {
                    pid: tensor.detach().cpu() for pid, tensor in param_dict.items()
                }
            torch.save(intercepts_cpu, path / "intercepts.pt")

        # Save slopes
        if self.config.mode == "random_slopes" and self.slopes:
            slopes_state = {pid: head.state_dict() for pid, head in self.slopes.items()}
            torch.save(slopes_state, path / "slopes.pt")

        # Save sample counts
        with open(path / "sample_counts.json", "w") as f:
            json.dump(self.participant_sample_counts, f)

        # Save variance history (if any)
        if self.variance_history:
            # Serialize VarianceComponents to JSON
            variance_history_data = [
                vc.model_dump() if hasattr(vc, "model_dump") else vc
                for vc in self.variance_history
            ]
            with open(path / "variance_history.json", "w") as f:
                json.dump(variance_history_data, f, indent=2)

    def load(self, path: Path, fixed_head: nn.Module | None = None) -> None:
        """Load random effects from disk.

        Parameters
        ----------
        path : Path
            Directory to load from.
        fixed_head : nn.Module | None
            Fixed head (required if mode='random_slopes').

        Raises
        ------
        FileNotFoundError
            If path doesn't exist.
        ValueError
            If mode='random_slopes' but fixed_head is None.

        Examples
        --------
        >>> manager.load(Path("model_checkpoint/random_effects"))
        """
        if not path.exists():
            raise FileNotFoundError(f"Random effects directory not found: {path}")

        # Load intercepts (nested dict)
        if self.config.mode == "random_intercepts":
            intercepts_path = path / "intercepts.pt"
            if intercepts_path.exists():
                self.intercepts = torch.load(intercepts_path, weights_only=False)

        # Load slopes
        if self.config.mode == "random_slopes":
            if fixed_head is None:
                raise ValueError(
                    "fixed_head is required when loading random slopes. "
                    "Pass the fixed effects head to load()."
                )

            slopes_path = path / "slopes.pt"
            if slopes_path.exists():
                slopes_state = torch.load(slopes_path, weights_only=False)
                self.slopes = {}
                for pid, state_dict in slopes_state.items():
                    head = copy.deepcopy(fixed_head)
                    head.load_state_dict(state_dict)
                    self.slopes[pid] = head

        # Load sample counts
        sample_counts_path = path / "sample_counts.json"
        if sample_counts_path.exists():
            with open(sample_counts_path) as f:
                self.participant_sample_counts = json.load(f)

        # Load variance history (if any)
        variance_history_path = path / "variance_history.json"
        if variance_history_path.exists():
            with open(variance_history_path) as f:
                variance_history_data = json.load(f)
            # Deserialize VarianceComponents from JSON
            from bead.active_learning.config import VarianceComponents  # noqa: PLC0415

            self.variance_history = [
                VarianceComponents(**vc_data) if isinstance(vc_data, dict) else vc_data
                for vc_data in variance_history_data
            ]
            # Restore variance_components from history
            if self.variance_history:
                last_vc = self.variance_history[-1]
                # Infer param name from effect type for backwards compatibility
                param_key = "slopes" if last_vc.effect_type == "slope" else "bias"
                self.variance_components = {param_key: last_vc}
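
For orientation, below is a minimal usage sketch of RandomEffectsManager in random-intercepts mode, assembled only from the methods shown in the diff above. It is not taken from the package itself: the MixedEffectsConfig keyword arguments other than mode (prior_mean, prior_variance, regularization_strength, adaptive_regularization, min_samples_for_random_effects) are assumed to be accepted by the config constructor because the code above reads them as config.* attributes; defaults and the exact signature may differ.

# Hypothetical usage sketch (not part of the released package).
# Assumes MixedEffectsConfig accepts the field names referenced as config.*
# attributes in bead/active_learning/models/random_effects.py.
from pathlib import Path

import torch

from bead.active_learning.config import MixedEffectsConfig
from bead.active_learning.models.random_effects import RandomEffectsManager

config = MixedEffectsConfig(
    mode="random_intercepts",           # or "random_slopes" / "fixed"
    prior_mean=0.0,                     # μ_0 (assumed field name)
    prior_variance=0.1,                 # σ²_0 for newly created intercepts (assumed)
    regularization_strength=0.01,       # λ in L_prior (assumed)
    adaptive_regularization=True,       # weight w_i = 1 / max(n_i, min_samples) (assumed)
    min_samples_for_random_effects=5,   # floor for the adaptive weight (assumed)
)
manager = RandomEffectsManager(config, n_classes=3)

# Track how many labeled trials each participant contributed.
manager.register_participant("alice", n_samples=2)
manager.register_participant("bob", n_samples=100)

# Per-participant intercepts for the "mu" parameter; drawn from N(μ_0, σ²_0) on first use.
bias_alice = manager.get_intercepts("alice", n_classes=3, param_name="mu")

# During training, add the prior penalty so sparse participants stay near μ_0.
logits = torch.randn(4, 3, requires_grad=True) + bias_alice
labels = torch.tensor([0, 2, 1, 0])
loss = torch.nn.functional.cross_entropy(logits, labels) + manager.compute_prior_loss()
loss.backward()

# Shrinkage before any variance estimate uses the fallback k = min_samples_for_random_effects:
# λ_alice = 2 / (2 + 5) ≈ 0.29 (pulled strongly toward μ_0),
# λ_bob   = 100 / (100 + 5) ≈ 0.95 (barely moves).
alice_shrunk = manager.get_intercepts_with_shrinkage("alice", 3, param_name="mu")

# After training: per-parameter variance components, then persist everything.
var_comps = manager.estimate_variance_components()  # e.g. {"mu": VarianceComponents(...)}
manager.save(Path("checkpoints/random_effects"))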