bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,852 @@
+"""Base interfaces for active learning models with mixed effects support.
+
+This module implements Generalized Linear Mixed Effects Models (GLMMs) following
+the standard formulation:
+
+    y = Xβ + Zu + ε
+
+Where:
+- Xβ: Fixed effects (population-level parameters, shared across all groups)
+- Zu: Random effects (group-specific parameters, e.g., per-participant)
+- u ~ N(0, G): Random effects with variance-covariance matrix G
+- ε: Residuals
+
+The implementation supports three modeling modes:
+1. Fixed effects: Standard model, ignores grouping structure
+2. Random intercepts: Per-group biases (Zu = bias vector per group)
+3. Random slopes: Per-group model parameters (Zu = separate model head per group)
+
+References
+----------
+- Bates et al. (2015). "Fitting Linear Mixed-Effects Models using lme4"
+- Simchoni & Rosset (2022). "Integrating Random Effects in Deep Neural Networks"
+"""
+
+from __future__ import annotations
+
+import json
+from abc import ABC, abstractmethod
+from collections import Counter
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+from bead.active_learning.config import (
+    MixedEffectsConfig,
+    RandomEffectsSpec,
+    VarianceComponents,
+)
+from bead.data.base import BeadBaseModel
+from bead.items.item import Item
+
+if TYPE_CHECKING:
+    import torch
+
+    from bead.items.item_template import ItemTemplate, TaskType
+
+__all__ = [
+    "ActiveLearningModel",
+    "ModelPrediction",
+    "MixedEffectsConfig",
+    "VarianceComponents",
+    "RandomEffectsSpec",
+]
+
+
+class ModelPrediction(BeadBaseModel):
+    """Prediction output for a single item.
+
+    Attributes
+    ----------
+    item_id : str
+        Unique identifier for the item.
+    probabilities : dict[str, float]
+        Predicted probabilities for each class/option.
+        Keys are option names (e.g., "option_a", "option_b") or class labels.
+    predicted_class : str
+        The predicted class/option with highest probability.
+    confidence : float
+        Confidence score (max probability).
+
+    Examples
+    --------
+    >>> prediction = ModelPrediction(
+    ...     item_id="abc123",
+    ...     probabilities={"option_a": 0.7, "option_b": 0.3},
+    ...     predicted_class="option_a",
+    ...     confidence=0.7
+    ... )
+    >>> prediction.predicted_class
+    'option_a'
+    """
+
+    item_id: str
+    probabilities: dict[str, float]
+    predicted_class: str
+    confidence: float
+
+
+class ActiveLearningModel(ABC):
+    """Base class for all active learning models with mixed effects support.
+
+    Implements GLMM-based active learning: y = Xβ + Zu + ε
+
+    All models must:
+    1. Support mixed effects (fixed, random_intercepts, random_slopes modes)
+    2. Accept participant_ids in train/predict/predict_proba (None for fixed effects)
+    3. Validate items match supported task types
+    4. Track variance components (if estimate_variance_components=True)
+
+    Attributes
+    ----------
+    config : dict[str, str | int | float | bool | None] | BeadBaseModel
+        Model configuration (task-type-specific).
+        Must include a `mixed_effects: MixedEffectsConfig` field.
+    supported_task_types : list[TaskType]
+        List of task types this model can handle.
+
+    Examples
+    --------
+    >>> class MyModel(ActiveLearningModel):
+    ...     def __init__(self, config):
+    ...         super().__init__(config)  # Validates mixed_effects field
+    ...     @property
+    ...     def supported_task_types(self):
+    ...         return ["forced_choice"]
+    ...     def validate_item_compatibility(self, item, item_template):
+    ...         pass
+    ...     def train(self, items, labels, participant_ids):
+    ...         return {}
+    ...     def predict(self, items, participant_ids):
+    ...         return []
+    ...     def predict_proba(self, items, participant_ids):
+    ...         return np.array([])
+    ...     def save(self, path):
+    ...         pass
+    ...     def load(self, path):
+    ...         pass
+    """
+
+    def __init__(
+        self, config: dict[str, str | int | float | bool | None] | BeadBaseModel
+    ) -> None:
+        """Initialize model with configuration.
+
+        Parameters
+        ----------
+        config : Any
+            Model configuration. Must have a `mixed_effects` field of type
+            MixedEffectsConfig.
+
+        Raises
+        ------
+        ValueError
+            If config is invalid or missing required fields.
+
+        Examples
+        --------
+        >>> from bead.config.active_learning import ForcedChoiceModelConfig
+        >>> config = ForcedChoiceModelConfig(
+        ...     n_classes=2,
+        ...     mixed_effects=MixedEffectsConfig(mode='fixed')
+        ... )
+        >>> model = ForcedChoiceModel(config)  # doctest: +SKIP
+        """
+        self.config = config
+
+        # Validate mixed_effects field exists
+        if not hasattr(config, "mixed_effects"):
+            raise ValueError(
+                f"Model config must have a 'mixed_effects' field of type "
+                f"MixedEffectsConfig, but {type(config).__name__} has no such field. "
+                f"Add: mixed_effects: MixedEffectsConfig = "
+                f"Field(default_factory=MixedEffectsConfig)"
+            )
+
+        # Validate mixed_effects is correct type
+        if not isinstance(config.mixed_effects, MixedEffectsConfig):
+            raise ValueError(
+                f"config.mixed_effects must be MixedEffectsConfig, but got "
+                f"{type(config.mixed_effects).__name__}. "
+                f"Ensure the field is properly typed: mixed_effects: MixedEffectsConfig"
+            )
+
+    def _validate_items_labels_length(
+        self, items: list[Item], labels: list[str]
+    ) -> None:
+        """Validate that items and labels have the same length.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items.
+        labels : list[str]
+            Training labels.
+
+        Raises
+        ------
+        ValueError
+            If items and labels have different lengths.
+        """
+        if len(items) != len(labels):
+            raise ValueError(
+                f"Number of items ({len(items)}) must match "
+                f"number of labels ({len(labels)})"
+            )
+
+    def _validate_participant_ids_required(
+        self, participant_ids: list[str] | None, mode: str
+    ) -> None:
+        """Validate that participant_ids is provided when required.
+
+        Parameters
+        ----------
+        participant_ids : list[str] | None
+            Participant IDs to validate.
+        mode : str
+            Mixed effects mode ('fixed', 'random_intercepts', 'random_slopes').
+
+        Raises
+        ------
+        ValueError
+            If participant_ids is None when mode requires it.
+        """
+        if participant_ids is None and mode != "fixed":
+            raise ValueError(
+                f"participant_ids is required when mode='{mode}'. "
+                f"For fixed effects, set mode='fixed' in config. "
+                f"For mixed effects, provide participant_ids as list[str]."
+            )
+
+    def _validate_participant_ids_length(
+        self, items: list[Item], participant_ids: list[str]
+    ) -> None:
+        """Validate that items and participant_ids have the same length.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items.
+        participant_ids : list[str]
+            Participant IDs.
+
+        Raises
+        ------
+        ValueError
+            If items and participant_ids have different lengths.
+        """
+        if len(items) != len(participant_ids):
+            raise ValueError(
+                f"Length mismatch: {len(items)} items != {len(participant_ids)} "
+                f"participant_ids. participant_ids must have same length as items."
+            )
+
+    def _validate_participant_ids_not_empty(self, participant_ids: list[str]) -> None:
+        """Validate that participant_ids does not contain empty strings.
+
+        Parameters
+        ----------
+        participant_ids : list[str]
+            Participant IDs to validate.
+
+        Raises
+        ------
+        ValueError
+            If participant_ids contains empty strings.
+        """
+        if any(not pid for pid in participant_ids):
+            raise ValueError(
+                "participant_ids cannot contain empty strings. "
+                "Ensure all participants have valid identifiers."
+            )
+
+    def _normalize_participant_ids(
+        self,
+        participant_ids: list[str] | None,
+        items: list[Item],
+        mode: str,
+    ) -> list[str]:
+        """Normalize participant_ids based on mode.
+
+        For fixed mode, replaces participant_ids with dummy values.
+        For mixed effects modes, validates and returns participant_ids as-is.
+
+        Parameters
+        ----------
+        participant_ids : list[str] | None
+            Participant IDs (may be None for fixed mode).
+        items : list[Item]
+            Training items (used to determine length).
+        mode : str
+            Mixed effects mode ('fixed', 'random_intercepts', 'random_slopes').
+
+        Returns
+        -------
+        list[str]
+            Normalized participant IDs (all "_fixed_" for fixed mode).
+
+        Raises
+        ------
+        ValueError
+            If participant_ids is None when mode requires it.
+        ValueError
+            If items and participant_ids have different lengths.
+        ValueError
+            If participant_ids contains empty strings.
+        """
+        import warnings  # noqa: PLC0415
+
+        if participant_ids is None:
+            if mode != "fixed":
+                self._validate_participant_ids_required(participant_ids, mode)
+            return ["_fixed_"] * len(items)
+
+        # Validate length and empty strings before normalizing
+        self._validate_participant_ids_length(items, participant_ids)
+        self._validate_participant_ids_not_empty(participant_ids)
+
+        if mode == "fixed":
+            warnings.warn(
+                "participant_ids provided but mode='fixed'. "
+                "Participant IDs will be ignored.",
+                UserWarning,
+                stacklevel=3,
+            )
+            return ["_fixed_"] * len(items)
+
+        return participant_ids
+
+    @property
+    @abstractmethod
+    def supported_task_types(self) -> list[TaskType]:
+        """Get list of task types this model supports.
+
+        Returns
+        -------
+        list[TaskType]
+            List of supported TaskType literals from items.models.
+
+        Examples
+        --------
+        >>> model.supported_task_types
+        ['forced_choice']
+        """
+        pass
+
+    @abstractmethod
+    def validate_item_compatibility(
+        self, item: Item, item_template: ItemTemplate
+    ) -> None:
+        """Validate that an item is compatible with this model.
+
+        Parameters
+        ----------
+        item : Item
+            Item to validate.
+        item_template : ItemTemplate
+            Template the item was constructed from.
+
+        Raises
+        ------
+        ValueError
+            If item's task_type is not in supported_task_types.
+        ValueError
+            If item is missing required elements.
+        ValueError
+            If item structure is incompatible with model.
+
+        Examples
+        --------
+        >>> model.validate_item_compatibility(item, template)  # doctest: +SKIP
+        """
+        pass
+
+    # Hook methods for model-specific implementations
+    @abstractmethod
+    def _prepare_training_data(
+        self,
+        items: list[Item],
+        labels: list[str],
+        participant_ids: list[str],
+        validation_items: list[Item] | None,
+        validation_labels: list[str] | None,
+    ) -> tuple[list[Item], list, list[str], list[Item] | None, list | None]:
+        """Prepare training data for model-specific training.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items.
+        labels : list[str]
+            Training labels.
+        participant_ids : list[str]
+            Normalized participant IDs.
+        validation_items : list[Item] | None
+            Validation items.
+        validation_labels : list[str] | None
+            Validation labels.
+
+        Returns
+        -------
+        tuple[list[Item], list, list[str], list[Item] | None, list | None]
+            Items, labels, participant_ids, val_items, val_labels.
+        """
+        pass
+
+    @abstractmethod
+    def _initialize_random_effects(self, n_classes: int) -> None:
+        """Initialize random effects manager.
+
+        Parameters
+        ----------
+        n_classes : int
+            Number of classes for random effects.
+        """
+        pass
+
+    @abstractmethod
+    def _do_training(
+        self,
+        items: list[Item],
+        labels_numeric: list,
+        participant_ids: list[str],
+        validation_items: list[Item] | None,
+        validation_labels_numeric: list | None,
+    ) -> dict[str, float]:
+        """Perform model-specific training.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items.
+        labels_numeric : list
+            Numeric labels (format depends on model).
+        participant_ids : list[str]
+            Participant IDs.
+        validation_items : list[Item] | None
+            Validation items.
+        validation_labels_numeric : list | None
+            Numeric validation labels.
+
+        Returns
+        -------
+        dict[str, float]
+            Training metrics.
+        """
+        pass
+
+    @abstractmethod
+    def _do_predict(
+        self, items: list[Item], participant_ids: list[str]
+    ) -> list[ModelPrediction]:
+        """Perform model-specific prediction.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Items to predict.
+        participant_ids : list[str]
+            Normalized participant IDs.
+
+        Returns
+        -------
+        list[ModelPrediction]
+            Predictions.
+        """
+        pass
+
+    @abstractmethod
+    def _do_predict_proba(
+        self, items: list[Item], participant_ids: list[str]
+    ) -> np.ndarray:
+        """Perform model-specific probability prediction.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Items to predict.
+        participant_ids : list[str]
+            Normalized participant IDs.
+
+        Returns
+        -------
+        np.ndarray
+            Probability array.
+        """
+        pass
+
+    @abstractmethod
+    def _get_save_state(self) -> dict[str, object]:
+        """Get model-specific state to save.
+
+        Returns
+        -------
+        dict[str, object]
+            State dictionary to include in config.json.
+        """
+        pass
+
+    @abstractmethod
+    def _save_model_components(self, save_path: Path) -> None:
+        """Save model-specific components (encoder, head, etc.).
+
+        Parameters
+        ----------
+        save_path : Path
+            Directory to save to.
+        """
+        pass
+
+    @abstractmethod
+    def _load_model_components(self, load_path: Path) -> None:
+        """Load model-specific components.
+
+        Parameters
+        ----------
+        load_path : Path
+            Directory to load from.
+        """
+        pass
+
+    @abstractmethod
+    def _restore_training_state(self, config_dict: dict[str, object]) -> None:
+        """Restore model-specific training state.
+
+        Parameters
+        ----------
+        config_dict : dict[str, object]
+            Configuration dictionary with training state.
+        """
+        pass
+
+    @abstractmethod
+    def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
+        """Get fixed head for random effects loading.
+
+        Returns
+        -------
+        nn.Module | None
+            Fixed head module, or None if not applicable.
+        """
+        pass
+
+    @abstractmethod
+    def _get_n_classes_for_random_effects(self) -> int:
+        """Get number of classes for random effects initialization.
+
+        Returns
+        -------
+        int
+            Number of classes.
+        """
+        pass
+
+    # Common implementations
+    def train(
+        self,
+        items: list[Item],
+        labels: list[str] | list[list[str]],
+        participant_ids: list[str] | None = None,
+        validation_items: list[Item] | None = None,
+        validation_labels: list[str] | list[list[str]] | None = None,
+    ) -> dict[str, float]:
+        """Train model on labeled items with participant identifiers.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Training items.
+        labels : list[str]
+            Training labels (format depends on task type).
+        participant_ids : list[str] | None
+            Participant identifier for each item.
+            - For fixed effects (mode='fixed'): Pass None (automatically handled).
+            - For mixed effects (mode='random_intercepts' or 'random_slopes'):
+              Must provide list[str] with same length as items.
+              Must not contain empty strings.
+        validation_items : list[Item] | None
+            Optional validation items.
+        validation_labels : list[str] | None
+            Optional validation labels.
+
+        Returns
+        -------
+        dict[str, float]
+            Training metrics including:
+            - "train_accuracy", "train_loss": Standard metrics
+            - "participant_variance": σ²_u (if estimate_variance_components=True)
+            - "n_participants": Number of unique participants
+            - "residual_variance": σ²_ε (if estimated)
+
+        Raises
+        ------
+        ValueError
+            If participant_ids is None when mode is 'random_intercepts'
+            or 'random_slopes'.
+        ValueError
+            If items, labels, and participant_ids have different lengths.
+        ValueError
+            If participant_ids contains empty strings.
+        ValueError
+            If validation data is incomplete.
+        ValueError
+            If labels are invalid for this task type.
+        """
+        # Validate input lengths (handle both list[str] and list[list[str]] labels)
+        if labels and isinstance(labels[0], list):
+            # Cloze model: labels is list[list[str]]
+            if len(items) != len(labels):
+                raise ValueError(
+                    f"Number of items ({len(items)}) must match "
+                    f"number of labels ({len(labels)})"
+                )
+        else:
+            # Standard models: labels is list[str]
+            self._validate_items_labels_length(items, labels)
+
+        # Validate and normalize participant_ids
+        participant_ids = self._normalize_participant_ids(
+            participant_ids, items, self.config.mixed_effects.mode
+        )
+
+        if (validation_items is None) != (validation_labels is None):
+            raise ValueError(
+                "Both validation_items and validation_labels must be "
+                "provided, or neither"
+            )
+
+        # Prepare training data (model-specific)
+        (
+            prepared_items,
+            labels_numeric,
+            participant_ids,
+            validation_items,
+            validation_labels_numeric,
+        ) = self._prepare_training_data(
+            items, labels, participant_ids, validation_items, validation_labels
+        )
+
+        # Initialize random effects
+        n_classes = self._get_n_classes_for_random_effects()
+        self._initialize_random_effects(n_classes)
+
+        # Register participants for adaptive regularization
+        if hasattr(self, "random_effects") and self.random_effects is not None:
+            participant_counts = Counter(participant_ids)
+            for pid, count in participant_counts.items():
+                self.random_effects.register_participant(pid, count)
+
+        # Perform training (model-specific)
+        metrics = self._do_training(
+            prepared_items,
+            labels_numeric,
+            participant_ids,
+            validation_items,
+            validation_labels_numeric,
+        )
+
+        self._is_fitted = True
+
+        # Estimate variance components
+        if (
+            self.config.mixed_effects.estimate_variance_components
+            and hasattr(self, "random_effects")
+            and self.random_effects is not None
+        ):
+            var_comps = self.random_effects.estimate_variance_components()
+            if var_comps:
+                var_comp = var_comps.get("mu") or var_comps.get("slopes")
+                if var_comp:
+                    if not hasattr(self, "variance_history"):
+                        self.variance_history = []
+                    self.variance_history.append(var_comp)
+                    metrics["participant_variance"] = var_comp.variance
+                    metrics["n_participants"] = var_comp.n_groups
+
+        return metrics
+
+    def predict(
+        self, items: list[Item], participant_ids: list[str] | None = None
+    ) -> list[ModelPrediction]:
+        """Predict class labels for items with participant identifiers.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Items to predict.
+        participant_ids : list[str] | None
+            Participant identifier for each item.
+            - For fixed effects (mode='fixed'): Pass None.
+            - For mixed effects: Must provide list[str] with same length as items.
+            - For unknown participants: Use population mean (prior) for random effects.
+
+        Returns
+        -------
+        list[ModelPrediction]
+            Predictions with probabilities and predicted class for each item.
+
+        Raises
+        ------
+        ValueError
+            If model has not been trained.
+        ValueError
+            If participant_ids is None when mode requires mixed effects.
+        ValueError
+            If items and participant_ids have different lengths.
+        ValueError
+            If participant_ids contains empty strings.
+        ValueError
+            If items are incompatible with model.
+        """
+        if not self._is_fitted:
+            raise ValueError("Model not trained. Call train() before predict().")
+
+        # Validate and normalize participant_ids
+        participant_ids = self._normalize_participant_ids(
+            participant_ids, items, self.config.mixed_effects.mode
+        )
+
+        return self._do_predict(items, participant_ids)
+
+    def predict_proba(
+        self, items: list[Item], participant_ids: list[str] | None = None
+    ) -> np.ndarray:
+        """Predict class probabilities for items with participant identifiers.
+
+        Parameters
+        ----------
+        items : list[Item]
+            Items to predict.
+        participant_ids : list[str] | None
+            Participant identifier for each item.
+            - For fixed effects (mode='fixed'): Pass None.
+            - For mixed effects: Must provide list[str] with same length as items.
+
+        Returns
+        -------
+        np.ndarray
+            Array of shape (n_items, n_classes) with probabilities.
+            Each row sums to 1.0 for classification tasks.
+
+        Raises
+        ------
+        ValueError
+            If model has not been trained.
+        ValueError
+            If participant_ids is None when mode requires mixed effects.
+        ValueError
+            If items and participant_ids have different lengths.
+        ValueError
+            If participant_ids contains empty strings.
+        ValueError
+            If items are incompatible with model.
+        """
+        if not self._is_fitted:
+            raise ValueError("Model not trained. Call train() before predict_proba().")
+
+        # Validate and normalize participant_ids
+        participant_ids = self._normalize_participant_ids(
+            participant_ids, items, self.config.mixed_effects.mode
+        )
+
+        return self._do_predict_proba(items, participant_ids)
+
+    def save(self, path: str) -> None:
+        """Save model to disk.
+
+        Parameters
+        ----------
+        path : str
+            File or directory path to save the model.
+
+        Raises
+        ------
+        ValueError
+            If model has not been trained.
+        """
+        if not self._is_fitted:
+            raise ValueError("Model not trained. Call train() before save().")
+
+        save_path = Path(path)
+        save_path.mkdir(parents=True, exist_ok=True)
+
+        # Save model-specific components
+        self._save_model_components(save_path)
+
+        # Save random effects (includes variance history)
+        if hasattr(self, "random_effects") and self.random_effects is not None:
+            # Copy variance_history from model to random_effects before saving
+            if hasattr(self, "variance_history"):
+                self.random_effects.variance_history = self.variance_history.copy()
+            self.random_effects.save(save_path / "random_effects")
+
+        # Save config with model-specific state
+        config_dict = self.config.model_dump()
+        save_state = self._get_save_state()
+        config_dict.update(save_state)
+
+        with open(save_path / "config.json", "w") as f:
+            json.dump(config_dict, f, indent=2)
+
+    def load(self, path: str) -> None:
+        """Load model from disk.
+
+        Parameters
+        ----------
+        path : str
+            File or directory path to load the model from.
+
+        Raises
+        ------
+        FileNotFoundError
+            If model file/directory does not exist.
+        """
+        load_path = Path(path)
+        if not load_path.exists():
+            raise FileNotFoundError(f"Model directory not found: {path}")
+
+        with open(load_path / "config.json") as f:
+            config_dict = json.load(f)
+
+        # Restore model-specific training state (before reconstructing config)
+        self._restore_training_state(config_dict)
+
+        # Load model-specific components (which will reconstruct the config)
+        # This must happen before initializing random effects so config is correct
+        self._load_model_components(load_path)
+
+        # Initialize and load random effects
+        n_classes = self._get_n_classes_for_random_effects()
+        from bead.active_learning.models.random_effects import (  # noqa: PLC0415
+            RandomEffectsManager,
+        )
+
+        # Check if model uses vocab_size instead of n_classes (e.g., ClozeModel)
+        if hasattr(self, "tokenizer") and hasattr(self.tokenizer, "vocab_size"):
+            # ClozeModel: use vocab_size
+            self.random_effects = RandomEffectsManager(
+                self.config.mixed_effects, vocab_size=n_classes
+            )
+        else:
+            # Standard models: use n_classes
+            self.random_effects = RandomEffectsManager(
+                self.config.mixed_effects, n_classes=n_classes
+            )
+        random_effects_path = load_path / "random_effects"
+        if random_effects_path.exists():
+            fixed_head = self._get_random_effects_fixed_head()
+            self.random_effects.load(random_effects_path, fixed_head=fixed_head)
+            # Restore variance history from random_effects
+            if hasattr(self.random_effects, "variance_history"):
+                if not hasattr(self, "variance_history"):
+                    self.variance_history = []
+                self.variance_history = self.random_effects.variance_history.copy()
+
+        # Move to device (model-specific)
+        if hasattr(self, "encoder"):
+            self.encoder.to(self.config.device)
+        if hasattr(self, "model"):
+            self.model.to(self.config.device)
+
+        self._is_fitted = True
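
The GLMM decomposition y = Xβ + Zu + ε referenced in the module docstring can be illustrated with a minimal numpy sketch. The arrays, shapes, and values below are invented toy data and are not part of the bead package; only the decomposition itself comes from the docstring above.

import numpy as np

rng = np.random.default_rng(0)

n_obs, n_fixed, n_groups = 6, 2, 3
X = rng.normal(size=(n_obs, n_fixed))     # fixed-effects design matrix
beta = np.array([0.5, -1.0])              # population-level coefficients (shared)

# Z maps each observation to its group (participant); here observation i
# belongs to group i % n_groups.
Z = np.zeros((n_obs, n_groups))
Z[np.arange(n_obs), np.arange(n_obs) % n_groups] = 1.0

u = rng.normal(scale=0.3, size=n_groups)  # random intercepts, u ~ N(0, G)
eps = rng.normal(scale=0.1, size=n_obs)   # residual noise

y = X @ beta + Z @ u + eps                # the linear predictor from the docstring
print(y.shape)                            # (6,)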
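
The "random_intercepts" mode described in the class docstring amounts to adding a per-participant bias vector to the fixed-effect logits, with unseen participants falling back to the population mean (the prior noted in the predict() docstring). The sketch below illustrates that idea with invented values and plain numpy; it stands in for, and does not reproduce, bead's actual RandomEffectsManager.

import numpy as np

# fixed-effect logits for two items over two classes (toy values)
logits_fixed = np.array([[1.2, -0.3],
                         [0.1,  0.4]])

# one learned bias vector per known participant (toy values)
intercepts = {"p01": np.array([0.5, -0.5]), "p02": np.array([0.0, 0.0])}
participant_ids = ["p01", "p03"]  # "p03" is unseen

# unseen participants get the population mean (zeros) instead of a learned bias
adjusted = np.stack([
    logits_fixed[i] + intercepts.get(pid, np.zeros(logits_fixed.shape[1]))
    for i, pid in enumerate(participant_ids)
])
probs = np.exp(adjusted) / np.exp(adjusted).sum(axis=1, keepdims=True)  # row-wise softmax
print(probs.round(3))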