bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,811 @@
|
|
|
1
|
+
"""Ordinal scale model for ordered rating scales (Likert, sliders, etc.).
|
|
2
|
+
|
|
3
|
+
Implements truncated normal distribution for bounded continuous responses on [0, 1].
|
|
4
|
+
Supports GLMM with participant-level random effects (intercepts and slopes).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import tempfile
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
from torch.distributions import Normal
|
|
17
|
+
from transformers import AutoModel, AutoTokenizer, TrainingArguments
|
|
18
|
+
|
|
19
|
+
from bead.active_learning.config import MixedEffectsConfig, VarianceComponents
|
|
20
|
+
from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
|
|
21
|
+
from bead.active_learning.models.random_effects import RandomEffectsManager
|
|
22
|
+
from bead.active_learning.trainers.data_collator import MixedEffectsDataCollator
|
|
23
|
+
from bead.active_learning.trainers.dataset_utils import items_to_dataset
|
|
24
|
+
from bead.active_learning.trainers.metrics import compute_regression_metrics
|
|
25
|
+
from bead.active_learning.trainers.model_wrapper import EncoderRegressionWrapper
|
|
26
|
+
from bead.config.active_learning import OrdinalScaleModelConfig
|
|
27
|
+
from bead.items.item import Item
|
|
28
|
+
from bead.items.item_template import ItemTemplate, TaskType
|
|
29
|
+
|
|
30
|
+
__all__ = ["OrdinalScaleModel"]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class OrdinalScaleModel(ActiveLearningModel):
|
|
34
|
+
"""Model for ordinal_scale tasks with bounded continuous responses.
|
|
35
|
+
|
|
36
|
+
Uses truncated normal distribution on [scale_min, scale_max] to model
|
|
37
|
+
slider/Likert responses while properly handling endpoints (0 and 1).
|
|
38
|
+
Supports three modes: fixed effects, random intercepts, random slopes.
|
|
39
|
+
|
|
40
|
+
Parameters
|
|
41
|
+
----------
|
|
42
|
+
config : OrdinalScaleModelConfig
|
|
43
|
+
Configuration object containing all model parameters.
|
|
44
|
+
|
|
45
|
+
Attributes
|
|
46
|
+
----------
|
|
47
|
+
config : OrdinalScaleModelConfig
|
|
48
|
+
Model configuration.
|
|
49
|
+
tokenizer : AutoTokenizer
|
|
50
|
+
Transformer tokenizer.
|
|
51
|
+
encoder : AutoModel
|
|
52
|
+
Transformer encoder model.
|
|
53
|
+
regression_head : nn.Sequential
|
|
54
|
+
Regression head (fixed effects head) - outputs continuous μ.
|
|
55
|
+
random_effects : RandomEffectsManager
|
|
56
|
+
Manager for participant-level random effects.
|
|
57
|
+
variance_history : list[VarianceComponents]
|
|
58
|
+
Variance component estimates over training (for diagnostics).
|
|
59
|
+
_is_fitted : bool
|
|
60
|
+
Whether model has been trained.
|
|
61
|
+
|
|
62
|
+
Examples
|
|
63
|
+
--------
|
|
64
|
+
>>> from uuid import uuid4
|
|
65
|
+
>>> from bead.items.item import Item
|
|
66
|
+
>>> from bead.config.active_learning import OrdinalScaleModelConfig
|
|
67
|
+
>>> items = [
|
|
68
|
+
... Item(
|
|
69
|
+
... item_template_id=uuid4(),
|
|
70
|
+
... rendered_elements={"text": f"Sentence {i}"}
|
|
71
|
+
... )
|
|
72
|
+
... for i in range(10)
|
|
73
|
+
... ]
|
|
74
|
+
>>> labels = ["0.3", "0.7"] * 5 # Continuous values as strings
|
|
75
|
+
>>> config = OrdinalScaleModelConfig( # doctest: +SKIP
|
|
76
|
+
... num_epochs=1, batch_size=2, device="cpu"
|
|
77
|
+
... )
|
|
78
|
+
>>> model = OrdinalScaleModel(config=config) # doctest: +SKIP
|
|
79
|
+
>>> metrics = model.train(items, labels, participant_ids=None) # doctest: +SKIP
|
|
80
|
+
>>> predictions = model.predict(items[:3], participant_ids=None) # doctest: +SKIP
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
def __init__(
    self,
    config: OrdinalScaleModelConfig | None = None,
) -> None:
    """Set up the tokenizer, the encoder and the empty training state.

    Parameters
    ----------
    config : OrdinalScaleModelConfig | None
        Model configuration. When nothing (or a falsy value) is passed, a
        default ``OrdinalScaleModelConfig()`` is created.
    """
    self.config = config or OrdinalScaleModelConfig()

    # The base class is initialized with the resolved configuration.
    super().__init__(self.config)

    # Tokenizer and encoder are both loaded from ``config.model_name``;
    # the encoder is moved to the configured device.
    pretrained_name = self.config.model_name
    self.tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    self.encoder = AutoModel.from_pretrained(pretrained_name)
    self.encoder.to(self.config.device)

    # These are filled in later (see ``_initialize_regression_head`` and
    # ``_initialize_random_effects``); they start out unset/empty.
    self.regression_head: nn.Sequential | None = None
    self.random_effects: RandomEffectsManager | None = None
    self.variance_history: list[VarianceComponents] = []
    self._is_fitted = False
|
|
110
|
+
|
|
111
|
+
@property
def supported_task_types(self) -> list[TaskType]:
    """Task types this model can be trained on.

    Returns
    -------
    list[TaskType]
        A single-element list holding "ordinal_scale".
    """
    task_types: list[TaskType] = ["ordinal_scale"]
    return task_types
|
|
121
|
+
|
|
122
|
+
def validate_item_compatibility(
    self, item: Item, item_template: ItemTemplate
) -> None:
    """Check that an item's template declares the ordinal scale task type.

    Only ``item_template.task_type`` is inspected; ``item`` itself is not
    read by this check.

    Parameters
    ----------
    item : Item
        Item to validate.
    item_template : ItemTemplate
        Template the item was constructed from.

    Raises
    ------
    ValueError
        If task_type is not "ordinal_scale".
    """
    # Guard clause: the matching task type is the only accepted case.
    if item_template.task_type == "ordinal_scale":
        return

    raise ValueError(
        f"Expected task_type 'ordinal_scale', got '{item_template.task_type}'"
    )
|
|
143
|
+
|
|
144
|
+
def _initialize_regression_head(self) -> None:
|
|
145
|
+
"""Initialize regression head for continuous output μ."""
|
|
146
|
+
hidden_size = self.encoder.config.hidden_size
|
|
147
|
+
|
|
148
|
+
# Single output for location parameter μ
|
|
149
|
+
self.regression_head = nn.Sequential(
|
|
150
|
+
nn.Linear(hidden_size, 256),
|
|
151
|
+
nn.ReLU(),
|
|
152
|
+
nn.Dropout(0.1),
|
|
153
|
+
nn.Linear(256, 1), # Output μ (location parameter)
|
|
154
|
+
)
|
|
155
|
+
self.regression_head.to(self.config.device)
|
|
156
|
+
|
|
157
|
+
def _encode_texts(self, texts: list[str]) -> torch.Tensor:
|
|
158
|
+
"""Encode texts using transformer.
|
|
159
|
+
|
|
160
|
+
Parameters
|
|
161
|
+
----------
|
|
162
|
+
texts : list[str]
|
|
163
|
+
Texts to encode.
|
|
164
|
+
|
|
165
|
+
Returns
|
|
166
|
+
-------
|
|
167
|
+
torch.Tensor
|
|
168
|
+
Encoded representations of shape (batch_size, hidden_size).
|
|
169
|
+
"""
|
|
170
|
+
encodings = self.tokenizer(
|
|
171
|
+
texts,
|
|
172
|
+
padding=True,
|
|
173
|
+
truncation=True,
|
|
174
|
+
max_length=self.config.max_length,
|
|
175
|
+
return_tensors="pt",
|
|
176
|
+
)
|
|
177
|
+
encodings = {k: v.to(self.config.device) for k, v in encodings.items()}
|
|
178
|
+
|
|
179
|
+
outputs = self.encoder(**encodings)
|
|
180
|
+
return outputs.last_hidden_state[:, 0, :]
|
|
181
|
+
|
|
182
|
+
def _prepare_inputs(self, items: list[Item]) -> torch.Tensor:
|
|
183
|
+
"""Prepare inputs for encoding.
|
|
184
|
+
|
|
185
|
+
For ordinal scale tasks, concatenates all rendered elements.
|
|
186
|
+
|
|
187
|
+
Parameters
|
|
188
|
+
----------
|
|
189
|
+
items : list[Item]
|
|
190
|
+
Items to encode.
|
|
191
|
+
|
|
192
|
+
Returns
|
|
193
|
+
-------
|
|
194
|
+
torch.Tensor
|
|
195
|
+
Encoded representations.
|
|
196
|
+
"""
|
|
197
|
+
texts = []
|
|
198
|
+
for item in items:
|
|
199
|
+
# Concatenate all rendered elements
|
|
200
|
+
all_text = " ".join(item.rendered_elements.values())
|
|
201
|
+
texts.append(all_text)
|
|
202
|
+
return self._encode_texts(texts)
|
|
203
|
+
|
|
204
|
+
def _truncated_normal_log_prob(
|
|
205
|
+
self, y: torch.Tensor, mu: torch.Tensor, sigma: float
|
|
206
|
+
) -> torch.Tensor:
|
|
207
|
+
"""Compute log probability of truncated normal distribution.
|
|
208
|
+
|
|
209
|
+
Uses truncated normal on [scale_min, scale_max] to properly handle
|
|
210
|
+
endpoint responses (0.0 and 1.0) without arbitrary nudging.
|
|
211
|
+
|
|
212
|
+
Parameters
|
|
213
|
+
----------
|
|
214
|
+
y : torch.Tensor
|
|
215
|
+
Observed values, shape (batch,).
|
|
216
|
+
mu : torch.Tensor
|
|
217
|
+
Location parameters (before truncation), shape (batch,).
|
|
218
|
+
sigma : float
|
|
219
|
+
Scale parameter (standard deviation).
|
|
220
|
+
|
|
221
|
+
Returns
|
|
222
|
+
-------
|
|
223
|
+
torch.Tensor
|
|
224
|
+
Log probabilities, shape (batch,).
|
|
225
|
+
"""
|
|
226
|
+
base_dist = Normal(mu.squeeze(), sigma)
|
|
227
|
+
|
|
228
|
+
# Unnormalized log prob
|
|
229
|
+
log_prob_unnorm = base_dist.log_prob(y)
|
|
230
|
+
|
|
231
|
+
# Normalizer: log(Φ((high-μ)/σ) - Φ((low-μ)/σ))
|
|
232
|
+
alpha = (self.config.scale.min - mu.squeeze()) / sigma
|
|
233
|
+
beta = (self.config.scale.max - mu.squeeze()) / sigma
|
|
234
|
+
normalizer = base_dist.cdf(beta) - base_dist.cdf(alpha)
|
|
235
|
+
|
|
236
|
+
# Clamp to avoid log(0)
|
|
237
|
+
normalizer = torch.clamp(normalizer, min=1e-8)
|
|
238
|
+
log_normalizer = torch.log(normalizer)
|
|
239
|
+
|
|
240
|
+
return log_prob_unnorm - log_normalizer
|
|
241
|
+
|
|
242
|
+
def _prepare_training_data(
|
|
243
|
+
self,
|
|
244
|
+
items: list[Item],
|
|
245
|
+
labels: list[str],
|
|
246
|
+
participant_ids: list[str],
|
|
247
|
+
validation_items: list[Item] | None,
|
|
248
|
+
validation_labels: list[str] | None,
|
|
249
|
+
) -> tuple[
|
|
250
|
+
list[Item], list[float], list[str], list[Item] | None, list[float] | None
|
|
251
|
+
]:
|
|
252
|
+
"""Prepare training data for ordinal scale model.
|
|
253
|
+
|
|
254
|
+
Parameters
|
|
255
|
+
----------
|
|
256
|
+
items : list[Item]
|
|
257
|
+
Training items.
|
|
258
|
+
labels : list[str]
|
|
259
|
+
Training labels (continuous values as strings).
|
|
260
|
+
participant_ids : list[str]
|
|
261
|
+
Normalized participant IDs.
|
|
262
|
+
validation_items : list[Item] | None
|
|
263
|
+
Validation items.
|
|
264
|
+
validation_labels : list[str] | None
|
|
265
|
+
Validation labels.
|
|
266
|
+
|
|
267
|
+
Returns
|
|
268
|
+
-------
|
|
269
|
+
tuple[list[Item], list[float], list[str], list[Item] | None, list[float] | None]
|
|
270
|
+
Prepared items, numeric labels (floats), participant_ids,
|
|
271
|
+
validation_items, numeric validation_labels.
|
|
272
|
+
"""
|
|
273
|
+
# Parse labels to floats and validate bounds
|
|
274
|
+
try:
|
|
275
|
+
y_values = [float(label) for label in labels]
|
|
276
|
+
except ValueError as e:
|
|
277
|
+
raise ValueError(
|
|
278
|
+
f"Labels must be numeric strings (e.g., '0.5', '0.75'). Got error: {e}"
|
|
279
|
+
) from e
|
|
280
|
+
|
|
281
|
+
# Validate all values are within bounds
|
|
282
|
+
for i, val in enumerate(y_values):
|
|
283
|
+
if not (self.config.scale.min <= val <= self.config.scale.max):
|
|
284
|
+
raise ValueError(
|
|
285
|
+
f"Label at index {i} ({val}) is outside bounds "
|
|
286
|
+
f"[{self.config.scale.min}, {self.config.scale.max}]"
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
self._initialize_regression_head()
|
|
290
|
+
|
|
291
|
+
# Convert validation labels if provided
|
|
292
|
+
val_y_numeric = None
|
|
293
|
+
if validation_items is not None and validation_labels is not None:
|
|
294
|
+
try:
|
|
295
|
+
val_y_numeric = [float(label) for label in validation_labels]
|
|
296
|
+
except ValueError as e:
|
|
297
|
+
raise ValueError(
|
|
298
|
+
f"Validation labels must be numeric strings. Got error: {e}"
|
|
299
|
+
) from e
|
|
300
|
+
|
|
301
|
+
# Validate bounds for validation labels
|
|
302
|
+
for i, val in enumerate(val_y_numeric):
|
|
303
|
+
if not (self.config.scale.min <= val <= self.config.scale.max):
|
|
304
|
+
raise ValueError(
|
|
305
|
+
f"Validation label at index {i} ({val}) is outside bounds "
|
|
306
|
+
f"[{self.config.scale.min}, {self.config.scale.max}]"
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
return items, y_values, participant_ids, validation_items, val_y_numeric
|
|
310
|
+
|
|
311
|
+
def _initialize_random_effects(self, n_classes: int) -> None:
    """Create the random effects manager and store it on the model.

    The manager is built from ``config.mixed_effects`` and replaces any
    manager stored previously.

    Parameters
    ----------
    n_classes : int
        Number of classes (1 for regression, i.e. a scalar bias for μ).
    """
    manager = RandomEffectsManager(
        self.config.mixed_effects,
        n_classes=n_classes,
    )
    self.random_effects = manager
|
|
323
|
+
|
|
324
|
+
def _do_training(
|
|
325
|
+
self,
|
|
326
|
+
items: list[Item],
|
|
327
|
+
labels_numeric: list[float],
|
|
328
|
+
participant_ids: list[str],
|
|
329
|
+
validation_items: list[Item] | None,
|
|
330
|
+
validation_labels_numeric: list[float] | None,
|
|
331
|
+
) -> dict[str, float]:
|
|
332
|
+
"""Perform ordinal scale model training.
|
|
333
|
+
|
|
334
|
+
Parameters
|
|
335
|
+
----------
|
|
336
|
+
items : list[Item]
|
|
337
|
+
Training items.
|
|
338
|
+
labels_numeric : list[float]
|
|
339
|
+
Numeric labels (continuous values).
|
|
340
|
+
participant_ids : list[str]
|
|
341
|
+
Participant IDs.
|
|
342
|
+
validation_items : list[Item] | None
|
|
343
|
+
Validation items.
|
|
344
|
+
validation_labels_numeric : list[float] | None
|
|
345
|
+
Numeric validation labels.
|
|
346
|
+
|
|
347
|
+
Returns
|
|
348
|
+
-------
|
|
349
|
+
dict[str, float]
|
|
350
|
+
Training metrics.
|
|
351
|
+
"""
|
|
352
|
+
# Convert validation_labels_numeric back to string labels for validation metrics
|
|
353
|
+
validation_labels = None
|
|
354
|
+
if validation_items is not None and validation_labels_numeric is not None:
|
|
355
|
+
validation_labels = [str(val) for val in validation_labels_numeric]
|
|
356
|
+
|
|
357
|
+
# Use HuggingFace Trainer for fixed and random_intercepts modes
|
|
358
|
+
# random_slopes requires custom loop due to per-participant heads
|
|
359
|
+
use_huggingface_trainer = self.config.mixed_effects.mode in (
|
|
360
|
+
"fixed",
|
|
361
|
+
"random_intercepts",
|
|
362
|
+
)
|
|
363
|
+
|
|
364
|
+
if use_huggingface_trainer:
|
|
365
|
+
metrics = self._train_with_huggingface_trainer(
|
|
366
|
+
items,
|
|
367
|
+
labels_numeric,
|
|
368
|
+
participant_ids,
|
|
369
|
+
validation_items,
|
|
370
|
+
validation_labels,
|
|
371
|
+
)
|
|
372
|
+
else:
|
|
373
|
+
# Use custom training loop for random_slopes
|
|
374
|
+
metrics = self._train_with_custom_loop(
|
|
375
|
+
items,
|
|
376
|
+
labels_numeric,
|
|
377
|
+
participant_ids,
|
|
378
|
+
validation_items,
|
|
379
|
+
validation_labels,
|
|
380
|
+
)
|
|
381
|
+
|
|
382
|
+
# Add validation MSE if validation data provided and not already computed
|
|
383
|
+
if (
|
|
384
|
+
validation_items is not None
|
|
385
|
+
and validation_labels is not None
|
|
386
|
+
and "val_mse" not in metrics
|
|
387
|
+
):
|
|
388
|
+
# Validation with placeholder participant_ids for mixed effects
|
|
389
|
+
if self.config.mixed_effects.mode == "fixed":
|
|
390
|
+
val_participant_ids = ["_fixed_"] * len(validation_items)
|
|
391
|
+
else:
|
|
392
|
+
val_participant_ids = ["_validation_"] * len(validation_items)
|
|
393
|
+
val_predictions = self._do_predict(validation_items, val_participant_ids)
|
|
394
|
+
val_pred_values = [float(p.predicted_class) for p in val_predictions]
|
|
395
|
+
val_true_values = [float(label) for label in validation_labels]
|
|
396
|
+
val_mse = np.mean(
|
|
397
|
+
[
|
|
398
|
+
(pred - true) ** 2
|
|
399
|
+
for pred, true in zip(val_pred_values, val_true_values, strict=True)
|
|
400
|
+
]
|
|
401
|
+
)
|
|
402
|
+
metrics["val_mse"] = val_mse
|
|
403
|
+
|
|
404
|
+
return metrics
|
|
405
|
+
|
|
406
|
+
def _train_with_huggingface_trainer(
    self,
    items: list[Item],
    y_numeric: list[float],
    participant_ids: list[str],
    validation_items: list[Item] | None,
    validation_labels: list[str] | None,
) -> dict[str, float]:
    """Train encoder and regression head with ``MixedEffectsTrainer``.

    Builds train (and optionally eval) datasets, wraps the encoder and the
    regression head in an ``EncoderRegressionWrapper``, trains inside a
    temporary checkpoint directory and then evaluates on the training set
    and, if present, on the validation set.

    Parameters
    ----------
    items : list[Item]
        Training items.
    y_numeric : list[float]
        Numeric labels (continuous values).
    participant_ids : list[str]
        Participant IDs.
    validation_items : list[Item] | None
        Validation items.
    validation_labels : list[str] | None
        Validation labels as numeric strings.

    Returns
    -------
    dict[str, float]
        "train_loss", "train_mse", "train_mae", "train_r2" and, when
        validation data was given, "val_mse", "val_mae", "val_r2". Missing
        evaluation entries default to 0.0.
    """
    max_length = self.config.max_length

    train_dataset = items_to_dataset(
        items=items,
        labels=y_numeric,
        participant_ids=participant_ids,
        tokenizer=self.tokenizer,
        max_length=max_length,
    )

    # The eval dataset exists only when both validation inputs are given.
    eval_dataset = None
    if validation_items is not None and validation_labels is not None:
        val_targets = [float(label) for label in validation_labels]
        # Validation rows get a placeholder participant ID per mode.
        if self.config.mixed_effects.mode != "fixed":
            placeholder = "_validation_"
        else:
            placeholder = "_fixed_"
        eval_dataset = items_to_dataset(
            items=validation_items,
            labels=val_targets,
            participant_ids=[placeholder] * len(validation_items),
            tokenizer=self.tokenizer,
            max_length=max_length,
        )

    # Encoder and head are exposed to the trainer as one wrapped model.
    wrapped_model = EncoderRegressionWrapper(
        encoder=self.encoder, regression_head=self.regression_head
    )
    data_collator = MixedEffectsDataCollator(tokenizer=self.tokenizer)

    # Checkpoints go to a temporary directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as scratch_root:
        checkpoint_dir = Path(scratch_root) / "checkpoints"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        eval_strategy = "epoch" if eval_dataset is not None else "no"
        training_args = TrainingArguments(
            output_dir=str(checkpoint_dir),
            num_train_epochs=self.config.num_epochs,
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=self.config.batch_size,
            learning_rate=self.config.learning_rate,
            logging_steps=10,
            eval_strategy=eval_strategy,
            save_strategy="epoch",
            save_total_limit=1,
            load_best_model_at_end=False,
            report_to="none",
            remove_unused_columns=False,
            use_cpu=self.config.device == "cpu",
        )

        # Imported here rather than at module level to avoid a circular
        # import.
        from bead.active_learning.trainers.mixed_effects import (  # noqa: PLC0415
            MixedEffectsTrainer,
        )

        trainer = MixedEffectsTrainer(
            model=wrapped_model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            tokenizer=self.tokenizer,
            random_effects_manager=self.random_effects,
            compute_metrics=compute_regression_metrics,
        )

        train_result = trainer.train()

        # Training-set metrics come from an evaluation pass over the
        # training data; they are reported under "eval_*" keys.
        train_eval = trainer.evaluate(eval_dataset=train_dataset)
        metrics: dict[str, float] = {
            "train_loss": float(train_result.training_loss)
        }
        metrics["train_mse"] = train_eval.get("eval_mse", 0.0)
        metrics["train_mae"] = train_eval.get("eval_mae", 0.0)
        metrics["train_r2"] = train_eval.get("eval_r2", 0.0)

        if eval_dataset is not None:
            val_eval = trainer.evaluate(eval_dataset=eval_dataset)
            metrics["val_mse"] = val_eval.get("eval_mse", 0.0)
            metrics["val_mae"] = val_eval.get("eval_mae", 0.0)
            metrics["val_r2"] = val_eval.get("eval_r2", 0.0)

    return metrics
|
|
529
|
+
|
|
530
|
+
def _train_with_custom_loop(
    self,
    items: list[Item],
    y_numeric: list[float],
    participant_ids: list[str],
    validation_items: list[Item] | None,
    validation_labels: list[str] | None,
) -> dict[str, float]:
    """Train using custom loop for random_slopes mode.

    Each example is scored by the head returned from
    ``self.random_effects.get_slopes`` for its participant. The loss is the
    negated mean of ``self._truncated_normal_log_prob`` plus
    ``self.random_effects.compute_prior_loss()``.

    Parameters
    ----------
    items : list[Item]
        Training items.
    y_numeric : list[float]
        Numeric labels (continuous values).
    participant_ids : list[str]
        Participant IDs, aligned index-by-index with ``items``.
    validation_items : list[Item] | None
        Validation items. Accepted for interface parity; not used here.
    validation_labels : list[str] | None
        Validation labels. Accepted for interface parity; not used here.

    Returns
    -------
    dict[str, float]
        ``train_loss`` and ``train_mse`` averaged over the batches of the
        final epoch. Both are ``0.0`` when no batch was run (no epochs or
        no items).
    """
    y = torch.tensor(y_numeric, dtype=torch.float, device=self.config.device)

    # Materialise every participant head BEFORE the optimizer is built.
    # The loop below requests heads with create_if_missing=True; a head first
    # created there would be absent from the optimizer's parameter list and
    # would therefore never receive an update.
    # NOTE(review): assumes get_slopes registers new heads in
    # self.random_effects.slopes - confirm against RandomEffectsManager.
    for pid in dict.fromkeys(participant_ids):
        self.random_effects.get_slopes(
            pid,
            fixed_head=self.regression_head,
            create_if_missing=True,
        )

    # Build optimizer parameters: shared encoder/head plus per-participant heads
    params_to_optimize = list(self.encoder.parameters()) + list(
        self.regression_head.parameters()
    )
    for head in self.random_effects.slopes.values():
        params_to_optimize.extend(head.parameters())

    optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)

    self.encoder.train()
    self.regression_head.train()

    batch_size = self.config.batch_size
    n_items = len(items)
    # Ceiling division; invariant across epochs, so computed once.
    n_batches = (n_items + batch_size - 1) // batch_size

    # Defined up front so the metrics below exist even with zero epochs.
    epoch_loss = 0.0
    epoch_mse = 0.0

    for _epoch in range(self.config.num_epochs):
        running_loss = 0.0
        running_mse = 0.0

        for start_idx in range(0, n_items, batch_size):
            end_idx = min(start_idx + batch_size, n_items)

            batch_items = items[start_idx:end_idx]
            batch_labels = y[start_idx:end_idx]
            batch_participant_ids = participant_ids[start_idx:end_idx]

            embeddings = self._prepare_inputs(batch_items)

            # Per-participant head for random_slopes
            mu_list = []
            for row, pid in enumerate(batch_participant_ids):
                participant_head = self.random_effects.get_slopes(
                    pid,
                    fixed_head=self.regression_head,
                    create_if_missing=True,
                )
                # Keep a batch dimension of one for the head, then reduce
                # its output to a scalar so the results can be stacked.
                mu_list.append(
                    participant_head(embeddings[row : row + 1]).squeeze()
                )
            mu = torch.stack(mu_list)

            # Negative log-likelihood of truncated normal
            log_probs = self._truncated_normal_log_prob(
                batch_labels, mu, self.config.sigma
            )
            loss_nll = -log_probs.mean()

            # Add prior regularization
            loss_prior = self.random_effects.compute_prior_loss()
            loss = loss_nll + loss_prior

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Also track MSE for interpretability
            running_mse += ((mu - batch_labels) ** 2).mean().item()

        # Guard against division by zero when there are no items.
        if n_batches:
            epoch_loss = running_loss / n_batches
            epoch_mse = running_mse / n_batches

    metrics: dict[str, float] = {
        "train_loss": epoch_loss,
        "train_mse": epoch_mse,
    }

    return metrics
|
|
631
|
+
|
|
632
|
+
def _do_predict(
    self, items: list[Item], participant_ids: list[str]
) -> list[ModelPrediction]:
    """Perform ordinal scale model prediction.

    Computes one value per item according to
    ``self.config.mixed_effects.mode`` and clamps it to
    ``[self.config.scale.min, self.config.scale.max]``.

    Parameters
    ----------
    items : list[Item]
        Items to predict.
    participant_ids : list[str]
        Normalized participant IDs, aligned index-by-index with ``items``.

    Returns
    -------
    list[ModelPrediction]
        Predictions with predicted_class as string representation of value.
        Empty when ``items`` is empty.

    Raises
    ------
    ValueError
        If the configured mixed effects mode is not one of ``"fixed"``,
        ``"random_intercepts"`` or ``"random_slopes"``.
    """
    self.encoder.eval()
    self.regression_head.eval()

    # Nothing to encode; also avoids torch.stack on an empty list below.
    if not items:
        return []

    mode = self.config.mixed_effects.mode

    with torch.no_grad():
        embeddings = self._prepare_inputs(items)

        # Forward pass depends on mixed effects mode
        if mode == "fixed":
            mu = self.regression_head(embeddings).squeeze(1)

        elif mode == "random_intercepts":
            mu = self.regression_head(embeddings).squeeze(1)
            for i, pid in enumerate(participant_ids):
                # Unknown participants: use prior mean (zero bias)
                bias = self.random_effects.get_intercepts(
                    pid, n_classes=1, param_name="mu", create_if_missing=False
                )
                mu[i] = mu[i] + bias.item()

        elif mode == "random_slopes":
            mu_list = []
            for i, pid in enumerate(participant_ids):
                # Unknown participants: use fixed head
                participant_head = self.random_effects.get_slopes(
                    pid, fixed_head=self.regression_head, create_if_missing=False
                )
                mu_list.append(participant_head(embeddings[i : i + 1]).squeeze())
            mu = torch.stack(mu_list)

        else:
            # Without this branch an unexpected mode would surface later as
            # an UnboundLocalError on ``mu``.
            raise ValueError(f"Unsupported mixed effects mode: {mode!r}")

        # Clamp predictions to bounds
        mu = torch.clamp(mu, self.config.scale.min, self.config.scale.max)
        predictions_array = mu.cpu().numpy()

    predictions = []
    for i, item in enumerate(items):
        pred_value = float(predictions_array[i])
        predictions.append(
            ModelPrediction(
                item_id=str(item.id),
                probabilities={},  # Not applicable for regression
                predicted_class=str(pred_value),  # Continuous value as string
                confidence=1.0,  # Not applicable for regression
            )
        )

    return predictions
|
|
696
|
+
|
|
697
|
+
def _do_predict_proba(
    self, items: list[Item], participant_ids: list[str]
) -> np.ndarray:
    """Perform ordinal scale model probability prediction.

    For ordinal scale regression, returns μ values directly: each value is
    parsed back from the ``predicted_class`` string produced by
    ``self._do_predict``.

    Parameters
    ----------
    items : list[Item]
        Items to predict.
    participant_ids : list[str]
        Normalized participant IDs.

    Returns
    -------
    np.ndarray
        Array of shape (n_items, 1) with predicted μ values.
    """
    predictions = self._do_predict(items, participant_ids)
    values = [float(p.predicted_class) for p in predictions]
    # reshape keeps the documented two-dimensional shape even when there are
    # no predictions; a nested-list literal would collapse to shape (0,).
    return np.array(values, dtype=float).reshape(-1, 1)
|
|
718
|
+
|
|
719
|
+
def _save_model_components(self, save_path: Path) -> None:
    """Persist the encoder, tokenizer and regression head.

    The encoder and tokenizer are both written with ``save_pretrained`` into
    the ``encoder`` subdirectory; the regression head's ``state_dict`` is
    written to ``regression_head.pt``.

    Parameters
    ----------
    save_path : Path
        Directory to save to.
    """
    # Encoder and tokenizer share one subdirectory.
    encoder_dir = save_path / "encoder"
    self.encoder.save_pretrained(encoder_dir)
    self.tokenizer.save_pretrained(encoder_dir)

    head_weights = self.regression_head.state_dict()
    torch.save(head_weights, save_path / "regression_head.pt")
|
|
734
|
+
|
|
735
|
+
def _get_save_state(self) -> dict[str, object]:
|
|
736
|
+
"""Get model-specific state to save.
|
|
737
|
+
|
|
738
|
+
Returns
|
|
739
|
+
-------
|
|
740
|
+
dict[str, object]
|
|
741
|
+
State dictionary.
|
|
742
|
+
"""
|
|
743
|
+
return {}
|
|
744
|
+
|
|
745
|
+
def _restore_training_state(self, config_dict: dict[str, object]) -> None:
|
|
746
|
+
"""Restore model-specific training state.
|
|
747
|
+
|
|
748
|
+
Parameters
|
|
749
|
+
----------
|
|
750
|
+
config_dict : dict[str, object]
|
|
751
|
+
Configuration dictionary with training state.
|
|
752
|
+
"""
|
|
753
|
+
# OrdinalScaleModel doesn't have additional training state to restore
|
|
754
|
+
pass
|
|
755
|
+
|
|
756
|
+
def _load_model_components(self, load_path: Path) -> None:
    """Load model-specific components.

    Rebuilds ``self.config`` from ``config.json``, reloads the encoder and
    tokenizer from the ``encoder`` subdirectory, re-creates the regression
    head and loads its weights from ``regression_head.pt``, then moves the
    encoder and head to ``self.config.device``.

    Parameters
    ----------
    load_path : Path
        Directory to load from.
    """
    # Load config.json to reconstruct config. JSON is read as UTF-8
    # explicitly so the result does not depend on the platform locale.
    with open(load_path / "config.json", encoding="utf-8") as f:
        config_dict = json.load(f)

    # Reconstruct MixedEffectsConfig if needed (it is stored as a plain dict)
    mixed_effects = config_dict.get("mixed_effects")
    if isinstance(mixed_effects, dict):
        config_dict["mixed_effects"] = MixedEffectsConfig(**mixed_effects)

    self.config = OrdinalScaleModelConfig(**config_dict)

    encoder_dir = load_path / "encoder"
    self.encoder = AutoModel.from_pretrained(encoder_dir)
    self.tokenizer = AutoTokenizer.from_pretrained(encoder_dir)

    self._initialize_regression_head()
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, so a tampered checkpoint cannot run
    # arbitrary code on load.
    head_state = torch.load(
        load_path / "regression_head.pt",
        map_location=self.config.device,
        weights_only=True,
    )
    self.regression_head.load_state_dict(head_state)

    self.encoder.to(self.config.device)
    self.regression_head.to(self.config.device)
|
|
790
|
+
|
|
791
|
+
def _get_n_classes_for_random_effects(self) -> int:
|
|
792
|
+
"""Get the number of classes for initializing RandomEffectsManager.
|
|
793
|
+
|
|
794
|
+
For ordinal scale models, this is 1 (scalar bias).
|
|
795
|
+
|
|
796
|
+
Returns
|
|
797
|
+
-------
|
|
798
|
+
int
|
|
799
|
+
Always 1 for regression.
|
|
800
|
+
"""
|
|
801
|
+
return 1
|
|
802
|
+
|
|
803
|
+
def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
|
|
804
|
+
"""Get the fixed head for random effects.
|
|
805
|
+
|
|
806
|
+
Returns
|
|
807
|
+
-------
|
|
808
|
+
torch.nn.Module | None
|
|
809
|
+
The regression head, or None if not applicable.
|
|
810
|
+
"""
|
|
811
|
+
return self.regression_head
|