bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/models/binary.py
@@ -0,0 +1,910 @@
"""Binary model for yes/no or true/false judgments.

Expected architecture: Binary classification with 2-class output.
Different from 2AFC in semantics - represents absolute judgment rather than choice.
"""

from __future__ import annotations

import tempfile
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, TrainingArguments

from bead.active_learning.config import VarianceComponents
from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
from bead.active_learning.models.random_effects import RandomEffectsManager
from bead.active_learning.trainers.data_collator import MixedEffectsDataCollator
from bead.active_learning.trainers.dataset_utils import items_to_dataset
from bead.active_learning.trainers.metrics import compute_binary_metrics
from bead.active_learning.trainers.model_wrapper import EncoderClassifierWrapper
from bead.config.active_learning import BinaryModelConfig
from bead.items.item import Item
from bead.items.item_template import ItemTemplate, TaskType

__all__ = ["BinaryModel"]

class BinaryModel(ActiveLearningModel):
    """Model for binary tasks (yes/no, true/false judgments).

    Uses true binary classification with a single output unit and sigmoid
    activation (logistic regression). This is more efficient than using
    2-class softmax, as we only need to output P(y=1) and compute
    P(y=0) = 1 - P(y=1).

    Parameters
    ----------
    config : BinaryModelConfig
        Configuration object containing all model parameters.

    Attributes
    ----------
    config : BinaryModelConfig
        Model configuration.
    tokenizer : AutoTokenizer
        Transformer tokenizer.
    encoder : AutoModel
        Transformer encoder model.
    classifier_head : nn.Sequential
        Classification head (fixed effects head) - outputs single logit.
    num_classes : int
        Number of output units (always 1 for binary classification).
    label_names : list[str] | None
        Label names (e.g., ["no", "yes"] sorted alphabetically).
    positive_class : str | None
        Which label corresponds to y=1 (second alphabetically).
    random_effects : RandomEffectsManager
        Manager for participant-level random effects (scalar biases).
    variance_history : list[VarianceComponents]
        Variance component estimates over training (for diagnostics).
    _is_fitted : bool
        Whether model has been trained.

    Examples
    --------
    >>> from uuid import uuid4
    >>> from bead.items.item import Item
    >>> from bead.config.active_learning import BinaryModelConfig
    >>> items = [
    ...     Item(
    ...         item_template_id=uuid4(),
    ...         rendered_elements={"text": f"Sentence {i}"}
    ...     )
    ...     for i in range(10)
    ... ]
    >>> labels = ["yes"] * 5 + ["no"] * 5
    >>> config = BinaryModelConfig(  # doctest: +SKIP
    ...     num_epochs=1, batch_size=2, device="cpu"
    ... )
    >>> model = BinaryModel(config=config)  # doctest: +SKIP
    >>> metrics = model.train(items, labels, participant_ids=None)  # doctest: +SKIP
    >>> predictions = model.predict(items[:3], participant_ids=None)  # doctest: +SKIP

    Notes
    -----
    This model uses BCEWithLogitsLoss instead of CrossEntropyLoss, and applies
    sigmoid activation to get probabilities. Random intercepts are scalar values
    (1-dimensional) that shift the logit for each participant.
    """

    def __init__(
        self,
        config: BinaryModelConfig | None = None,
    ) -> None:
        """Initialize binary model.

        Parameters
        ----------
        config : BinaryModelConfig | None
            Configuration object. If None, uses default configuration.
        """
        self.config = config or BinaryModelConfig()

        # Validate mixed_effects configuration
        super().__init__(self.config)

        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self.encoder = AutoModel.from_pretrained(self.config.model_name)

        self.num_classes: int = 1  # Single output unit for binary classification
        self.label_names: list[str] | None = None
        self.positive_class: str | None = None  # Which label corresponds to 1
        self.classifier_head: nn.Sequential | None = None
        self._is_fitted = False

        # Initialize random effects manager
        self.random_effects: RandomEffectsManager | None = None
        self.variance_history: list[VarianceComponents] = []

        self.encoder.to(self.config.device)

    @property
    def supported_task_types(self) -> list[TaskType]:
        """Get supported task types.

        Returns
        -------
        list[TaskType]
            List containing "binary".
        """
        return ["binary"]

    def validate_item_compatibility(
        self, item: Item, item_template: ItemTemplate
    ) -> None:
        """Validate item is compatible with binary model.

        Parameters
        ----------
        item : Item
            Item to validate.
        item_template : ItemTemplate
            Template the item was constructed from.

        Raises
        ------
        ValueError
            If task_type is not "binary".
        """
        if item_template.task_type != "binary":
            raise ValueError(
                f"Expected task_type 'binary', got '{item_template.task_type}'"
            )

    def _initialize_classifier(self) -> None:
        """Initialize classification head for binary classification.

        Outputs a single value (logit) for sigmoid activation.
        """
        hidden_size = self.encoder.config.hidden_size

        # Single output unit for binary classification
        if self.config.encoder_mode == "dual_encoder":
            input_size = hidden_size * 2
        else:
            input_size = hidden_size

        self.classifier_head = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),  # Single output unit
        )
        self.classifier_head.to(self.config.device)

    def _encode_texts(self, texts: list[str]) -> torch.Tensor:
        """Encode texts using single encoder.

        Parameters
        ----------
        texts : list[str]
            Texts to encode.

        Returns
        -------
        torch.Tensor
            Encoded representations of shape (batch_size, hidden_size).
        """
        encodings = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.config.max_length,
            return_tensors="pt",
        )
        encodings = {k: v.to(self.config.device) for k, v in encodings.items()}

        outputs = self.encoder(**encodings)
        return outputs.last_hidden_state[:, 0, :]

    def _prepare_inputs(self, items: list[Item]) -> torch.Tensor:
        """Prepare inputs for encoding.

        For binary tasks, concatenates all rendered elements.

        Parameters
        ----------
        items : list[Item]
            Items to encode.

        Returns
        -------
        torch.Tensor
            Encoded representations.
        """
        texts = []
        for item in items:
            # Concatenate all rendered elements
            all_text = " ".join(item.rendered_elements.values())
            texts.append(all_text)
        return self._encode_texts(texts)

    def _validate_labels(self, labels: list[str]) -> None:
        """Validate that all labels are valid.

        Parameters
        ----------
        labels : list[str]
            Labels to validate.

        Raises
        ------
        ValueError
            If any label is not in label_names.
        """
        if self.label_names is None:
            raise ValueError("label_names not initialized")

        valid_labels = set(self.label_names)
        invalid = [label for label in labels if label not in valid_labels]
        if invalid:
            raise ValueError(
                f"Invalid labels found: {set(invalid)}. "
                f"Labels must be one of {valid_labels}."
            )

    def _prepare_training_data(
        self,
        items: list[Item],
        labels: list[str],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels: list[str] | None,
    ) -> tuple[
        list[Item], list[float], list[str], list[Item] | None, list[float] | None
    ]:
        """Prepare training data for binary model.

        Parameters
        ----------
        items : list[Item]
            Training items.
        labels : list[str]
            Training labels.
        participant_ids : list[str]
            Normalized participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels : list[str] | None
            Validation labels.

        Returns
        -------
        tuple[list[Item], list[float], list[str], list[Item] | None, list[float] | None]
            Prepared items, numeric labels, participant_ids, validation_items,
            numeric validation_labels.
        """
        # Initialize label names
        unique_labels = sorted(set(labels))
        if len(unique_labels) != 2:
            raise ValueError(
                f"Binary classification requires exactly 2 classes, "
                f"got {len(unique_labels)}: {unique_labels}"
            )
        self.label_names = unique_labels
        # Positive class is the second one alphabetically (index 1)
        self.positive_class = unique_labels[1]

        self._validate_labels(labels)
        self._initialize_classifier()

        # Convert labels to binary (0/1) floats for HuggingFace Trainer
        # Positive class (second alphabetically) = 1, negative = 0
        y_numeric = [1.0 if label == self.positive_class else 0.0 for label in labels]

        # Convert validation labels if provided
        val_y_numeric = None
        if validation_items is not None and validation_labels is not None:
            self._validate_labels(validation_labels)
            if len(validation_items) != len(validation_labels):
                raise ValueError(
                    f"Number of validation items ({len(validation_items)}) "
                    f"must match number of validation labels ({len(validation_labels)})"
                )
            val_y_numeric = [
                1.0 if label == self.positive_class else 0.0
                for label in validation_labels
            ]

        return items, y_numeric, participant_ids, validation_items, val_y_numeric

    def _initialize_random_effects(self, n_classes: int) -> None:
        """Initialize random effects manager.

        Parameters
        ----------
        n_classes : int
            Number of classes (1 for binary).
        """
        self.random_effects = RandomEffectsManager(
            self.config.mixed_effects, n_classes=n_classes
        )

    def _do_training(
        self,
        items: list[Item],
        labels_numeric: list[float],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels_numeric: list[float] | None,
    ) -> dict[str, float]:
        """Perform binary model training.

        Parameters
        ----------
        items : list[Item]
            Training items.
        labels_numeric : list[float]
            Numeric labels (0.0 or 1.0).
        participant_ids : list[str]
            Participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels_numeric : list[float] | None
            Numeric validation labels.

        Returns
        -------
        dict[str, float]
            Training metrics.
        """
        # Convert validation_labels_numeric back to string labels for validation metrics
        validation_labels = None
        if validation_items is not None and validation_labels_numeric is not None:
            validation_labels = [
                self.positive_class if label == 1.0 else self.label_names[0]
                for label in validation_labels_numeric
            ]

        # Use HuggingFace Trainer for fixed and random_intercepts modes
        # random_slopes requires custom loop due to per-participant heads
        use_huggingface_trainer = self.config.mixed_effects.mode in (
            "fixed",
            "random_intercepts",
        )

        if use_huggingface_trainer:
            metrics = self._train_with_huggingface_trainer(
                items,
                labels_numeric,
                participant_ids,
                validation_items,
                validation_labels,
            )
        else:
            # Use custom training loop for random_slopes
            metrics = self._train_with_custom_loop(
                items,
                labels_numeric,
                participant_ids,
                validation_items,
                validation_labels,
            )

        # Add validation accuracy if validation data provided
        if validation_items is not None and validation_labels is not None:
            # Validation with placeholder participant_ids for mixed effects
            # Use _do_predict directly since we're in training
            if self.config.mixed_effects.mode == "fixed":
                val_participant_ids = ["_fixed_"] * len(validation_items)
            else:
                val_participant_ids = ["_validation_"] * len(validation_items)
            val_predictions = self._do_predict(validation_items, val_participant_ids)
            val_pred_labels = [p.predicted_class for p in val_predictions]
            val_acc = sum(
                pred == true
                for pred, true in zip(val_pred_labels, validation_labels, strict=True)
            ) / len(validation_labels)
            metrics["val_accuracy"] = val_acc

        return metrics

    def _train_with_huggingface_trainer(
        self,
        items: list[Item],
        y_numeric: list[float],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels: list[str] | None,
    ) -> dict[str, float]:
        """Train using HuggingFace Trainer with mixed effects support.

        Parameters
        ----------
        items : list[Item]
            Training items.
        y_numeric : list[float]
            Numeric labels (0.0 or 1.0).
        participant_ids : list[str]
            Participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels : list[str] | None
            Validation labels.

        Returns
        -------
        dict[str, float]
            Training metrics.
        """
        # Convert items to HuggingFace Dataset
        train_dataset = items_to_dataset(
            items=items,
            labels=y_numeric,
            participant_ids=participant_ids,
            tokenizer=self.tokenizer,
            max_length=self.config.max_length,
        )

        # Create validation dataset if provided
        eval_dataset = None
        if validation_items is not None and validation_labels is not None:
            val_y_numeric = [
                1.0 if label == self.positive_class else 0.0
                for label in validation_labels
            ]
            eval_dataset = items_to_dataset(
                items=validation_items,
                labels=val_y_numeric,
                participant_ids=["_validation_"] * len(validation_items),
                tokenizer=self.tokenizer,
                max_length=self.config.max_length,
            )

        # Create wrapper model for Trainer
        wrapped_model = EncoderClassifierWrapper(
            encoder=self.encoder, classifier_head=self.classifier_head
        )

        # Create data collator
        data_collator = MixedEffectsDataCollator(tokenizer=self.tokenizer)

        # Create training arguments with checkpointing
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_dir = Path(tmpdir) / "checkpoints"
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            training_args = TrainingArguments(
                output_dir=str(checkpoint_dir),
                num_train_epochs=self.config.num_epochs,
                per_device_train_batch_size=self.config.batch_size,
                per_device_eval_batch_size=self.config.batch_size,
                learning_rate=self.config.learning_rate,
                logging_steps=10,
                eval_strategy="epoch" if eval_dataset is not None else "no",
                save_strategy="epoch",  # Save checkpoints every epoch
                save_total_limit=1,  # Keep only the latest checkpoint
                load_best_model_at_end=False,  # Don't auto-load best
                report_to="none",  # Disable wandb/tensorboard
                remove_unused_columns=False,  # Keep participant_id
                use_cpu=self.config.device == "cpu",  # Explicitly use CPU if specified
            )

            # Import here to avoid circular import
            from bead.active_learning.trainers.mixed_effects import (  # noqa: PLC0415
                MixedEffectsTrainer,
            )

            # Create trainer
            trainer = MixedEffectsTrainer(
                model=wrapped_model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=data_collator,
                tokenizer=self.tokenizer,
                random_effects_manager=self.random_effects,
                compute_metrics=compute_binary_metrics,
            )

            # Train (checkpoints are saved automatically by Trainer)
            train_result = trainer.train()

            # Get training metrics using evaluate (Trainer computes metrics during eval)
            train_metrics = trainer.evaluate(eval_dataset=train_dataset)
            # Trainer prefixes eval metrics with "eval_"
            metrics: dict[str, float] = {
                "train_loss": float(train_result.training_loss),
                "train_accuracy": train_metrics.get("eval_accuracy", 0.0),
                "train_precision": train_metrics.get("eval_precision", 0.0),
                "train_recall": train_metrics.get("eval_recall", 0.0),
                "train_f1": train_metrics.get("eval_f1", 0.0),
            }

            # Get validation metrics if eval_dataset was provided
            if eval_dataset is not None:
                val_metrics = trainer.evaluate(eval_dataset=eval_dataset)
                metrics.update(
                    {
                        "val_accuracy": val_metrics.get("eval_accuracy", 0.0),
                        "val_precision": val_metrics.get("eval_precision", 0.0),
                        "val_recall": val_metrics.get("eval_recall", 0.0),
                        "val_f1": val_metrics.get("eval_f1", 0.0),
                    }
                )

        # Estimate variance components
        if self.config.mixed_effects.estimate_variance_components:
            var_comps = self.random_effects.estimate_variance_components()
            if var_comps:
                var_comp = var_comps.get("mu") or var_comps.get("slopes")
                if var_comp:
                    self.variance_history.append(var_comp)
                    metrics["participant_variance"] = var_comp.variance
                    metrics["n_participants"] = var_comp.n_groups

        # Validation metrics (already computed by Trainer if eval_dataset provided)
        if eval_dataset is not None:
            val_metrics = trainer.evaluate(eval_dataset=eval_dataset)
            metrics.update(
                {
                    "val_accuracy": val_metrics.get("eval_accuracy", 0.0),
                    "val_precision": val_metrics.get("eval_precision", 0.0),
                    "val_recall": val_metrics.get("eval_recall", 0.0),
                    "val_f1": val_metrics.get("eval_f1", 0.0),
                }
            )

        return metrics

    def _train_with_custom_loop(
        self,
        items: list[Item],
        y_numeric: list[float],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels: list[str] | None,
    ) -> dict[str, float]:
        """Train using custom training loop (for random_slopes mode).

        Parameters
        ----------
        items : list[Item]
            Training items.
        y_numeric : list[float]
            Numeric labels (0.0 or 1.0).
        participant_ids : list[str]
            Participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels : list[str] | None
            Validation labels.

        Returns
        -------
        dict[str, float]
            Training metrics.
        """
        # Convert to tensor
        y = torch.tensor(y_numeric, dtype=torch.float, device=self.config.device)

        # Build optimizer parameters
        params_to_optimize = list(self.encoder.parameters()) + list(
            self.classifier_head.parameters()
        )

        # Add random effects parameters (for random_slopes)
        if self.config.mixed_effects.mode == "random_slopes":
            for head in self.random_effects.slopes.values():
                params_to_optimize.extend(head.parameters())

        optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
        criterion = nn.BCEWithLogitsLoss()

        self.encoder.train()
        self.classifier_head.train()

        for _epoch in range(self.config.num_epochs):
            n_batches = (
                len(items) + self.config.batch_size - 1
            ) // self.config.batch_size
            epoch_loss = 0.0
            epoch_correct = 0

            for i in range(n_batches):
                start_idx = i * self.config.batch_size
                end_idx = min(start_idx + self.config.batch_size, len(items))

                batch_items = items[start_idx:end_idx]
                batch_labels = y[start_idx:end_idx]
                batch_participant_ids = participant_ids[start_idx:end_idx]

                embeddings = self._prepare_inputs(batch_items)

                # Random slopes: per-participant head
                logits_list = []
                for j, pid in enumerate(batch_participant_ids):
                    participant_head = self.random_effects.get_slopes(
                        pid,
                        fixed_head=self.classifier_head,
                        create_if_missing=True,
                    )
                    logits_j = participant_head(embeddings[j : j + 1]).squeeze()
                    logits_list.append(logits_j)
                logits = torch.stack(logits_list)

                # Data loss + prior regularization
                loss_bce = criterion(logits, batch_labels)
                loss_prior = self.random_effects.compute_prior_loss()
                loss = loss_bce + loss_prior

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                predictions = (torch.sigmoid(logits) > 0.5).float()
                epoch_correct += (predictions == batch_labels).sum().item()

            epoch_acc = epoch_correct / len(items)
            epoch_loss = epoch_loss / n_batches

        metrics: dict[str, float] = {
            "train_accuracy": epoch_acc,
            "train_loss": epoch_loss,
        }

        # Estimate variance components
        if self.config.mixed_effects.estimate_variance_components:
            var_comps = self.random_effects.estimate_variance_components()
            if var_comps:
                var_comp = var_comps.get("mu") or var_comps.get("slopes")
                if var_comp:
                    self.variance_history.append(var_comp)
                    metrics["participant_variance"] = var_comp.variance
                    metrics["n_participants"] = var_comp.n_groups

        # Validation
        if validation_items is not None and validation_labels is not None:
            val_predictions = self.predict(
                validation_items,
                participant_ids=["_validation_"] * len(validation_items),
            )
            val_pred_labels = [p.predicted_class for p in val_predictions]
            val_acc = sum(
                pred == true
                for pred, true in zip(val_pred_labels, validation_labels, strict=True)
            ) / len(validation_labels)
            metrics["val_accuracy"] = val_acc

        return metrics

    def _do_predict(
        self, items: list[Item], participant_ids: list[str]
    ) -> list[ModelPrediction]:
        """Perform binary model prediction.

        Parameters
        ----------
        items : list[Item]
            Items to predict.
        participant_ids : list[str]
            Normalized participant IDs.

        Returns
        -------
        list[ModelPrediction]
            Predictions.
        """
        self.encoder.eval()
        self.classifier_head.eval()

        with torch.no_grad():
            embeddings = self._prepare_inputs(items)

            # Forward pass depends on mixed effects mode
            if self.config.mixed_effects.mode == "fixed":
                logits = self.classifier_head(embeddings).squeeze(1)  # (n_items,)

            elif self.config.mixed_effects.mode == "random_intercepts":
                logits = self.classifier_head(embeddings).squeeze(1)  # (n_items,)
                for i, pid in enumerate(participant_ids):
                    # Unknown participants: use prior mean (zero bias)
                    bias = self.random_effects.get_intercepts(
                        pid,
                        n_classes=self.num_classes,
                        param_name="mu",
                        create_if_missing=False,
                    )
                    logits[i] = logits[i] + bias.item()

            elif self.config.mixed_effects.mode == "random_slopes":
                logits_list = []
                for i, pid in enumerate(participant_ids):
                    # Unknown participants: use fixed head
                    participant_head = self.random_effects.get_slopes(
                        pid, fixed_head=self.classifier_head, create_if_missing=False
                    )
                    logits_i = participant_head(embeddings[i : i + 1]).squeeze()
                    logits_list.append(logits_i)
                logits = torch.stack(logits_list)

            # Compute probabilities using sigmoid
            proba_positive = torch.sigmoid(logits).cpu().numpy()  # P(y=1)
            pred_is_positive = proba_positive > 0.5

        predictions = []
        for i, item in enumerate(items):
            # Determine predicted class
            if pred_is_positive[i]:
                pred_label = self.positive_class
            else:
                pred_label = self.label_names[0]

            # Build probability dict: {negative_class: p0, positive_class: p1}
            p1 = float(proba_positive[i])
            p0 = 1.0 - p1
            prob_dict = {
                self.label_names[0]: p0,  # Negative class (first alphabetically)
                self.positive_class: p1,  # Positive class (second alphabetically)
            }

            predictions.append(
                ModelPrediction(
                    item_id=str(item.id),
                    probabilities=prob_dict,
                    predicted_class=pred_label,
                    confidence=max(p0, p1),
                )
            )

        return predictions

    def _do_predict_proba(
        self, items: list[Item], participant_ids: list[str]
    ) -> np.ndarray:
        """Perform binary model probability prediction.

        Parameters
        ----------
        items : list[Item]
            Items to predict.
        participant_ids : list[str]
            Normalized participant IDs.

        Returns
        -------
        np.ndarray
            Probability array of shape (n_items, 2).
        """
        self.encoder.eval()
        self.classifier_head.eval()

        with torch.no_grad():
            embeddings = self._prepare_inputs(items)

            # Forward pass depends on mixed effects mode
            if self.config.mixed_effects.mode == "fixed":
                logits = self.classifier_head(embeddings).squeeze(1)

            elif self.config.mixed_effects.mode == "random_intercepts":
                logits = self.classifier_head(embeddings).squeeze(1)
                for i, pid in enumerate(participant_ids):
                    bias = self.random_effects.get_intercepts(
                        pid,
                        n_classes=self.num_classes,
                        param_name="mu",
                        create_if_missing=False,
                    )
                    logits[i] = logits[i] + bias.item()

            elif self.config.mixed_effects.mode == "random_slopes":
                logits_list = []
                for i, pid in enumerate(participant_ids):
                    participant_head = self.random_effects.get_slopes(
                        pid, fixed_head=self.classifier_head, create_if_missing=False
                    )
                    logits_i = participant_head(embeddings[i : i + 1]).squeeze()
                    logits_list.append(logits_i)
                logits = torch.stack(logits_list)

            # Compute probabilities using sigmoid
            proba_positive = torch.sigmoid(logits).cpu().numpy()  # P(y=1)

        # Return (n_items, 2) array: [P(negative), P(positive)]
        proba = np.stack([1.0 - proba_positive, proba_positive], axis=1)

        return proba

    def _get_save_state(self) -> dict[str, object]:
        """Get model-specific state to save.

        Returns
        -------
        dict[str, object]
            State dictionary.
        """
        return {
            "num_classes": self.num_classes,
            "label_names": self.label_names,
            "positive_class": self.positive_class,
        }

    def _save_model_components(self, save_path: Path) -> None:
        """Save model-specific components.

        Parameters
        ----------
        save_path : Path
            Directory to save to.
        """
        self.encoder.save_pretrained(save_path / "encoder")
        self.tokenizer.save_pretrained(save_path / "encoder")

        torch.save(
            self.classifier_head.state_dict(),
            save_path / "classifier_head.pt",
        )

    def _restore_training_state(self, config_dict: dict[str, object]) -> None:
        """Restore model-specific training state.

        Parameters
        ----------
        config_dict : dict[str, object]
            Configuration dictionary with training state.
        """
        self.num_classes = config_dict.pop("num_classes")
        self.label_names = config_dict.pop("label_names")
        self.positive_class = config_dict.pop("positive_class")

    def _load_model_components(self, load_path: Path) -> None:
        """Load model-specific components.

        Parameters
        ----------
        load_path : Path
            Directory to load from.
        """
        # Load config.json to reconstruct config
        with open(load_path / "config.json") as f:
            import json  # noqa: PLC0415

            config_dict = json.load(f)

        # Reconstruct MixedEffectsConfig if needed
        if "mixed_effects" in config_dict and isinstance(
            config_dict["mixed_effects"], dict
        ):
            from bead.active_learning.config import MixedEffectsConfig  # noqa: PLC0415

            config_dict["mixed_effects"] = MixedEffectsConfig(
                **config_dict["mixed_effects"]
            )

        self.config = BinaryModelConfig(**config_dict)

        self.encoder = AutoModel.from_pretrained(load_path / "encoder")
        self.tokenizer = AutoTokenizer.from_pretrained(load_path / "encoder")

        self._initialize_classifier()
        self.classifier_head.load_state_dict(
            torch.load(
                load_path / "classifier_head.pt", map_location=self.config.device
            )
        )
        self.classifier_head.to(self.config.device)

    def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
        """Get fixed head for random effects loading.

        Returns
        -------
        nn.Module | None
            Fixed head module.
        """
        return self.classifier_head

    def _get_n_classes_for_random_effects(self) -> int:
        """Get number of classes for random effects initialization.

        Returns
        -------
        int
            Number of classes (1 for binary).
        """
        return self.num_classes