bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,943 @@
"""Model for categorical tasks (unordered N-class classification)."""

from __future__ import annotations

import tempfile
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, TrainingArguments

from bead.active_learning.config import VarianceComponents
from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
from bead.active_learning.models.random_effects import RandomEffectsManager
from bead.active_learning.trainers.data_collator import MixedEffectsDataCollator
from bead.active_learning.trainers.dataset_utils import items_to_dataset
from bead.active_learning.trainers.metrics import compute_multiclass_metrics
from bead.active_learning.trainers.model_wrapper import EncoderClassifierWrapper
from bead.config.active_learning import CategoricalModelConfig
from bead.items.item import Item
from bead.items.item_template import ItemTemplate, TaskType

__all__ = ["CategoricalModel"]


class CategoricalModel(ActiveLearningModel):
    """Model for categorical tasks with N unordered categories.

    Supports N-class classification (N ≥ 2) using any HuggingFace transformer
    model. Provides two encoding strategies: single encoder (concatenate
    categories) or dual encoder (separate embeddings).

    Parameters
    ----------
    config : CategoricalModelConfig
        Configuration object containing all model parameters.

    Attributes
    ----------
    config : CategoricalModelConfig
        Model configuration.
    tokenizer : AutoTokenizer
        Transformer tokenizer.
    encoder : AutoModel
        Transformer encoder model.
    classifier_head : nn.Sequential
        Classification head (fixed effects head).
    num_classes : int | None
        Number of classes (inferred from training data).
    category_names : list[str] | None
        Category names (e.g., ["entailment", "neutral", "contradiction"]).
    random_effects : RandomEffectsManager
        Manager for participant-level random effects.
    variance_history : list[VarianceComponents]
        Variance component estimates over training (for diagnostics).
    _is_fitted : bool
        Whether model has been trained.

    Examples
    --------
    >>> from uuid import uuid4
    >>> from bead.items.item import Item
    >>> from bead.config.active_learning import CategoricalModelConfig
    >>> items = [
    ...     Item(
    ...         item_template_id=uuid4(),
    ...         rendered_elements={"premise": "sent A", "hypothesis": "sent B"}
    ...     )
    ...     for _ in range(10)
    ... ]
    >>> labels = ["entailment"] * 5 + ["contradiction"] * 5
    >>> config = CategoricalModelConfig(  # doctest: +SKIP
    ...     num_epochs=1, batch_size=2, device="cpu"
    ... )
    >>> model = CategoricalModel(config=config)  # doctest: +SKIP
    >>> metrics = model.train(items, labels, participant_ids=None)  # doctest: +SKIP
    >>> predictions = model.predict(items[:3], participant_ids=None)  # doctest: +SKIP
    """

    def __init__(
        self,
        config: CategoricalModelConfig | None = None,
    ) -> None:
        """Initialize categorical model.

        Parameters
        ----------
        config : CategoricalModelConfig | None
            Configuration object. If None, uses default configuration.
        """
        self.config = config or CategoricalModelConfig()

        # Validate mixed_effects configuration
        super().__init__(self.config)

        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self.encoder = AutoModel.from_pretrained(self.config.model_name)

        self.num_classes: int | None = None
        self.category_names: list[str] | None = None
        self.classifier_head: nn.Sequential | None = None
        self._is_fitted = False

        # Initialize random effects manager
        self.random_effects: RandomEffectsManager | None = None
        self.variance_history: list[VarianceComponents] = []

        self.encoder.to(self.config.device)

    @property
    def supported_task_types(self) -> list[TaskType]:
        """Get supported task types.

        Returns
        -------
        list[TaskType]
            List containing "categorical".
        """
        return ["categorical"]

    def validate_item_compatibility(
        self, item: Item, item_template: ItemTemplate
    ) -> None:
        """Validate item is compatible with categorical model.

        Parameters
        ----------
        item : Item
            Item to validate.
        item_template : ItemTemplate
            Template the item was constructed from.

        Raises
        ------
        ValueError
            If task_type is not "categorical".
        """
        if item_template.task_type != "categorical":
            raise ValueError(
                f"Expected task_type 'categorical', got '{item_template.task_type}'"
            )

    def _initialize_classifier(self, num_classes: int) -> None:
        """Initialize classification head for given number of classes.

        Parameters
        ----------
        num_classes : int
            Number of output classes.
        """
        hidden_size = self.encoder.config.hidden_size

        if self.config.encoder_mode == "dual_encoder":
            input_size = hidden_size * num_classes
        else:
            input_size = hidden_size

        self.classifier_head = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )
        self.classifier_head.to(self.config.device)
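
    # Note: the head input width depends on the encoder mode. With
    # single_encoder it is hidden_size (one [CLS] vector per item); with
    # dual_encoder it is hidden_size * num_classes, because _encode_dual
    # concatenates one [CLS] vector per category. For a 768-dim encoder and
    # 3 categories, the dual-encoder head therefore maps 2304 -> 256 -> 3.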

    def _encode_single(self, texts: list[str]) -> torch.Tensor:
        """Encode texts using single encoder strategy.

        Concatenates all category texts with [SEP] tokens and encodes once.

        Parameters
        ----------
        texts : list[str]
            List of concatenated category texts for each item.

        Returns
        -------
        torch.Tensor
            Encoded representations of shape (batch_size, hidden_size).
        """
        encodings = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.config.max_length,
            return_tensors="pt",
        )
        encodings = {k: v.to(self.config.device) for k, v in encodings.items()}

        outputs = self.encoder(**encodings)
        return outputs.last_hidden_state[:, 0, :]

    def _encode_dual(self, categories_per_item: list[list[str]]) -> torch.Tensor:
        """Encode texts using dual encoder strategy.

        Encodes each category separately and concatenates embeddings.

        Parameters
        ----------
        categories_per_item : list[list[str]]
            List of category lists. Each inner list contains category texts
            for one item.

        Returns
        -------
        torch.Tensor
            Concatenated encodings of shape (batch_size, hidden_size * num_categories).
        """
        all_embeddings = []

        for categories in categories_per_item:
            category_embeddings = []
            for category_text in categories:
                encodings = self.tokenizer(
                    [category_text],
                    padding=True,
                    truncation=True,
                    max_length=self.config.max_length,
                    return_tensors="pt",
                )
                encodings = {k: v.to(self.config.device) for k, v in encodings.items()}

                outputs = self.encoder(**encodings)
                cls_embedding = outputs.last_hidden_state[0, 0, :]
                category_embeddings.append(cls_embedding)

            concatenated = torch.cat(category_embeddings, dim=0)
            all_embeddings.append(concatenated)

        return torch.stack(all_embeddings)
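
    # Note: each category text is encoded in its own forward pass, so dual
    # encoding costs roughly num_categories encoder calls per item (and this
    # loop runs them one text at a time rather than batched).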

    def _prepare_inputs(self, items: list[Item]) -> torch.Tensor:
        """Prepare inputs for encoding based on encoder mode.

        For categorical tasks, concatenates all rendered elements.

        Parameters
        ----------
        items : list[Item]
            Items to encode.

        Returns
        -------
        torch.Tensor
            Encoded representations.
        """
        if self.category_names is None:
            raise ValueError("Model not initialized. Call train() first.")

        if self.config.encoder_mode == "single_encoder":
            texts = []
            for item in items:
                # Concatenate all rendered elements
                all_text = " ".join(item.rendered_elements.values())
                texts.append(all_text)
            return self._encode_single(texts)
        else:
            categories_per_item = []
            for item in items:
                category_texts = list(item.rendered_elements.values())
                categories_per_item.append(category_texts)
            return self._encode_dual(categories_per_item)
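
    # Note: both branches iterate rendered_elements.values(), so the element
    # order the encoder sees is the dict's insertion order (guaranteed for
    # Python dicts since 3.7).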

    def _validate_labels(self, labels: list[str]) -> None:
        """Validate that all labels are valid category names.

        Parameters
        ----------
        labels : list[str]
            Labels to validate.

        Raises
        ------
        ValueError
            If any label is not in category_names.
        """
        if self.category_names is None:
            raise ValueError("category_names not initialized")

        valid_labels = set(self.category_names)
        invalid = [label for label in labels if label not in valid_labels]
        if invalid:
            raise ValueError(
                f"Invalid labels found: {set(invalid)}. "
                f"Labels must be one of {valid_labels}."
            )

    def _prepare_training_data(
        self,
        items: list[Item],
        labels: list[str],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels: list[str] | None,
    ) -> tuple[list[Item], list[int], list[str], list[Item] | None, list[int] | None]:
        """Prepare training data for categorical model.

        Parameters
        ----------
        items : list[Item]
            Training items.
        labels : list[str]
            Training labels.
        participant_ids : list[str]
            Normalized participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels : list[str] | None
            Validation labels.

        Returns
        -------
        tuple[list[Item], list[int], list[str], list[Item] | None, list[int] | None]
            Prepared items, numeric labels, participant_ids, validation_items,
            numeric validation_labels.
        """
        unique_labels = sorted(set(labels))
        self.num_classes = len(unique_labels)
        self.category_names = unique_labels

        self._validate_labels(labels)
        self._initialize_classifier(self.num_classes)

        label_to_idx = {label: idx for idx, label in enumerate(self.category_names)}
        y_numeric = [label_to_idx[label] for label in labels]

        # Convert validation labels if provided
        val_y_numeric = None
        if validation_items is not None and validation_labels is not None:
            self._validate_labels(validation_labels)
            if len(validation_items) != len(validation_labels):
                raise ValueError(
                    f"Number of validation items ({len(validation_items)}) "
                    f"must match number of validation labels ({len(validation_labels)})"
                )
            val_y_numeric = [label_to_idx[label] for label in validation_labels]

        return items, y_numeric, participant_ids, validation_items, val_y_numeric
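
    # Note: category_names is the sorted set of training labels, so class
    # indices are alphabetical; e.g. the labels {"entailment", "neutral",
    # "contradiction"} map to contradiction=0, entailment=1, neutral=2.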

    def _initialize_random_effects(self, n_classes: int) -> None:
        """Initialize random effects manager.

        Parameters
        ----------
        n_classes : int
            Number of classes.
        """
        self.random_effects = RandomEffectsManager(
            self.config.mixed_effects, n_classes=n_classes
        )

    def _do_training(
        self,
        items: list[Item],
        labels_numeric: list[int],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels_numeric: list[int] | None,
    ) -> dict[str, float]:
        """Perform categorical model training.

        Parameters
        ----------
        items : list[Item]
            Training items.
        labels_numeric : list[int]
            Numeric labels (class indices).
        participant_ids : list[str]
            Participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels_numeric : list[int] | None
            Numeric validation labels.

        Returns
        -------
        dict[str, float]
            Training metrics.
        """
        # Convert validation_labels_numeric back to string labels for validation metrics
        validation_labels = None
        if validation_items is not None and validation_labels_numeric is not None:
            validation_labels = [
                self.category_names[label_idx]
                for label_idx in validation_labels_numeric
            ]

        # Use HuggingFace Trainer for fixed and random_intercepts modes
        if self.config.mixed_effects.mode in ("fixed", "random_intercepts"):
            metrics = self._train_with_huggingface_trainer(
                items=items,
                y_numeric=labels_numeric,
                participant_ids=participant_ids,
                validation_items=validation_items,
                validation_labels=validation_labels,
            )
        else:
            # Use custom loop for random_slopes mode
            metrics = self._train_with_custom_loop(
                items=items,
                y_numeric=labels_numeric,
                participant_ids=participant_ids,
                validation_items=validation_items,
                validation_labels=validation_labels,
            )

        # Add validation accuracy if validation data provided and not already computed
        if (
            validation_items is not None
            and validation_labels is not None
            and "val_accuracy" not in metrics
        ):
            # Validation with placeholder participant_ids for mixed effects
            if self.config.mixed_effects.mode == "fixed":
                val_participant_ids = ["_fixed_"] * len(validation_items)
            else:
                val_participant_ids = ["_validation_"] * len(validation_items)
            val_predictions = self._do_predict(validation_items, val_participant_ids)
            val_pred_labels = [p.predicted_class for p in val_predictions]
            val_acc = sum(
                pred == true
                for pred, true in zip(val_pred_labels, validation_labels, strict=True)
            ) / len(validation_labels)
            metrics["val_accuracy"] = val_acc

        return metrics
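
    # In mixed-effects terms, the three modes score item i for participant p
    # roughly as:
    #   fixed:             logits = f(h_i)
    #   random_intercepts: logits = f(h_i) + b_p   (per-participant bias)
    #   random_slopes:     logits = f_p(h_i)       (per-participant head)
    # where h_i is the encoder embedding and f is the fixed classifier head;
    # see the forward passes in _train_with_custom_loop and _do_predict.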

    def _train_with_huggingface_trainer(
        self,
        items: list[Item],
        y_numeric: list[int],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels: list[str] | None,
    ) -> dict[str, float]:
        """Train using HuggingFace Trainer with mixed effects support.

        Parameters
        ----------
        items : list[Item]
            Training items.
        y_numeric : list[int]
            Numeric labels (class indices).
        participant_ids : list[str]
            Participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels : list[str] | None
            Validation labels.

        Returns
        -------
        dict[str, float]
            Training metrics.
        """
        # Convert items to HuggingFace Dataset
        train_dataset = items_to_dataset(
            items=items,
            labels=y_numeric,
            participant_ids=participant_ids,
            tokenizer=self.tokenizer,
            max_length=self.config.max_length,
        )

        # Create validation dataset if provided
        eval_dataset = None
        if validation_items is not None and validation_labels is not None:
            label_to_idx = {label: idx for idx, label in enumerate(self.category_names)}
            val_y_numeric = [label_to_idx[label] for label in validation_labels]
            val_participant_ids = (
                ["_validation_"] * len(validation_items)
                if self.config.mixed_effects.mode != "fixed"
                else ["_fixed_"] * len(validation_items)
            )
            eval_dataset = items_to_dataset(
                items=validation_items,
                labels=val_y_numeric,
                participant_ids=val_participant_ids,
                tokenizer=self.tokenizer,
                max_length=self.config.max_length,
            )

        # Create wrapper model for Trainer
        wrapped_model = EncoderClassifierWrapper(
            encoder=self.encoder, classifier_head=self.classifier_head
        )

        # Create data collator
        data_collator = MixedEffectsDataCollator(tokenizer=self.tokenizer)

        # Create metrics computation function
        def compute_metrics_fn(eval_pred: object) -> dict[str, float]:
            return compute_multiclass_metrics(eval_pred, num_labels=self.num_classes)

        # Create training arguments with checkpointing
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_dir = Path(tmpdir) / "checkpoints"
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            training_args = TrainingArguments(
                output_dir=str(checkpoint_dir),
                num_train_epochs=self.config.num_epochs,
                per_device_train_batch_size=self.config.batch_size,
                per_device_eval_batch_size=self.config.batch_size,
                learning_rate=self.config.learning_rate,
                logging_steps=10,
                eval_strategy="epoch" if eval_dataset is not None else "no",
                save_strategy="epoch",
                save_total_limit=1,
                load_best_model_at_end=False,
                report_to="none",
                remove_unused_columns=False,
                use_cpu=self.config.device == "cpu",
            )

            # Import here to avoid circular import
            from bead.active_learning.trainers.mixed_effects import (  # noqa: PLC0415
                MixedEffectsTrainer,
            )

            # Create trainer
            trainer = MixedEffectsTrainer(
                model=wrapped_model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=data_collator,
                tokenizer=self.tokenizer,
                random_effects_manager=self.random_effects,
                compute_metrics=compute_metrics_fn,
            )

            # Train
            train_result = trainer.train()

            # Get training metrics
            train_metrics = trainer.evaluate(eval_dataset=train_dataset)
            metrics: dict[str, float] = {
                "train_loss": float(train_result.training_loss),
                "train_accuracy": train_metrics.get("eval_accuracy", 0.0),
                "train_precision": train_metrics.get("eval_precision", 0.0),
                "train_recall": train_metrics.get("eval_recall", 0.0),
                "train_f1": train_metrics.get("eval_f1", 0.0),
            }

            # Get validation metrics if eval_dataset was provided
            if eval_dataset is not None:
                val_metrics = trainer.evaluate(eval_dataset=eval_dataset)
                metrics.update(
                    {
                        "val_accuracy": val_metrics.get("eval_accuracy", 0.0),
                        "val_precision": val_metrics.get("eval_precision", 0.0),
                        "val_recall": val_metrics.get("eval_recall", 0.0),
                        "val_f1": val_metrics.get("eval_f1", 0.0),
                    }
                )

        # Estimate variance components
        if self.config.mixed_effects.estimate_variance_components:
            var_comps = self.random_effects.estimate_variance_components()
            if var_comps:
                var_comp = var_comps.get("mu") or var_comps.get("slopes")
                if var_comp:
                    self.variance_history.append(var_comp)
                    metrics["participant_variance"] = var_comp.variance
                    metrics["n_participants"] = var_comp.n_groups

        self._is_fitted = True

        return metrics
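
    # Note: estimate_variance_components() appears to key its estimates by
    # parameter: "mu" for random-intercept biases, "slopes" for
    # per-participant heads. Whichever is present is appended to
    # variance_history and reported as participant_variance / n_participants.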

    def _train_with_custom_loop(
        self,
        items: list[Item],
        y_numeric: list[int],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels: list[str] | None,
    ) -> dict[str, float]:
        """Train using custom training loop (for random_slopes mode).

        Parameters
        ----------
        items : list[Item]
            Training items.
        y_numeric : list[int]
            Numeric labels (class indices).
        participant_ids : list[str]
            Participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels : list[str] | None
            Validation labels.

        Returns
        -------
        dict[str, float]
            Training metrics.
        """
        # Convert to tensor
        y = torch.tensor(y_numeric, dtype=torch.long, device=self.config.device)

        # Build optimizer parameters
        params_to_optimize = list(self.encoder.parameters()) + list(
            self.classifier_head.parameters()
        )

        # Add random effects parameters (for random_slopes)
        if self.config.mixed_effects.mode == "random_slopes":
            for head in self.random_effects.slopes.values():
                params_to_optimize.extend(head.parameters())

        optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
        criterion = nn.CrossEntropyLoss()

        self.encoder.train()
        self.classifier_head.train()

        for _epoch in range(self.config.num_epochs):
            n_batches = (
                len(items) + self.config.batch_size - 1
            ) // self.config.batch_size
            epoch_loss = 0.0
            epoch_correct = 0

            for i in range(n_batches):
                start_idx = i * self.config.batch_size
                end_idx = min(start_idx + self.config.batch_size, len(items))

                batch_items = items[start_idx:end_idx]
                batch_labels = y[start_idx:end_idx]
                batch_participant_ids = participant_ids[start_idx:end_idx]

                embeddings = self._prepare_inputs(batch_items)

                # Forward pass depends on mixed effects mode
                if self.config.mixed_effects.mode == "fixed":
                    # Standard forward pass
                    logits = self.classifier_head(embeddings)

                elif self.config.mixed_effects.mode == "random_intercepts":
                    # Fixed head + per-participant bias
                    logits = self.classifier_head(embeddings)
                    for j, pid in enumerate(batch_participant_ids):
                        bias = self.random_effects.get_intercepts(
                            pid,
                            n_classes=self.num_classes,
                            param_name="mu",
                            create_if_missing=True,
                        )
                        logits[j] = logits[j] + bias

                elif self.config.mixed_effects.mode == "random_slopes":
                    # Per-participant head
                    logits_list = []
                    for j, pid in enumerate(batch_participant_ids):
                        participant_head = self.random_effects.get_slopes(
                            pid,
                            fixed_head=self.classifier_head,
                            create_if_missing=True,
                        )
                        logits_j = participant_head(embeddings[j : j + 1])
                        logits_list.append(logits_j)
                    logits = torch.cat(logits_list, dim=0)

                # Data loss + prior regularization
                loss_ce = criterion(logits, batch_labels)
                loss_prior = self.random_effects.compute_prior_loss()
                loss = loss_ce + loss_prior

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                predictions = torch.argmax(logits, dim=1)
                epoch_correct += (predictions == batch_labels).sum().item()

            epoch_acc = epoch_correct / len(items)
            epoch_loss = epoch_loss / n_batches

        self._is_fitted = True

        metrics: dict[str, float] = {
            "train_accuracy": epoch_acc,
            "train_loss": epoch_loss,
        }

        # Estimate variance components
        if self.config.mixed_effects.estimate_variance_components:
            var_comps = self.random_effects.estimate_variance_components()
            if var_comps:
                var_comp = var_comps.get("mu") or var_comps.get("slopes")
                if var_comp:
                    self.variance_history.append(var_comp)
                    metrics["participant_variance"] = var_comp.variance
                    metrics["n_participants"] = var_comp.n_groups

        if validation_items is not None and validation_labels is not None:
            self._validate_labels(validation_labels)

            if len(validation_items) != len(validation_labels):
                raise ValueError(
                    f"Number of validation items ({len(validation_items)}) "
                    f"must match number of validation labels ({len(validation_labels)})"
                )

            # Validation with placeholder participant_ids for mixed effects
            if self.config.mixed_effects.mode == "fixed":
                val_predictions = self.predict(validation_items, participant_ids=None)
            else:
                val_participant_ids = ["_validation_"] * len(validation_items)
                val_predictions = self.predict(
                    validation_items, participant_ids=val_participant_ids
                )
            val_pred_labels = [p.predicted_class for p in val_predictions]
            val_acc = sum(
                pred == true
                for pred, true in zip(val_pred_labels, validation_labels, strict=True)
            ) / len(validation_labels)
            metrics["val_accuracy"] = val_acc

        return metrics
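
    # Note: the custom loop optimizes cross-entropy plus compute_prior_loss().
    # The prior term presumably penalizes participant-specific parameters, the
    # penalized-likelihood analogue of shrinking random effects toward the
    # population values; its exact form lives in RandomEffectsManager.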

    def _do_predict(
        self, items: list[Item], participant_ids: list[str]
    ) -> list[ModelPrediction]:
        """Perform categorical model prediction.

        Parameters
        ----------
        items : list[Item]
            Items to predict.
        participant_ids : list[str]
            Normalized participant IDs.

        Returns
        -------
        list[ModelPrediction]
            Predictions.
        """
        self.encoder.eval()
        self.classifier_head.eval()

        with torch.no_grad():
            embeddings = self._prepare_inputs(items)

            # Forward pass depends on mixed effects mode
            if self.config.mixed_effects.mode == "fixed":
                logits = self.classifier_head(embeddings)

            elif self.config.mixed_effects.mode == "random_intercepts":
                logits = self.classifier_head(embeddings)
                for i, pid in enumerate(participant_ids):
                    # Unknown participants: use prior mean (zero bias)
                    bias = self.random_effects.get_intercepts(
                        pid,
                        n_classes=self.num_classes,
                        param_name="mu",
                        create_if_missing=False,
                    )
                    logits[i] = logits[i] + bias

            elif self.config.mixed_effects.mode == "random_slopes":
                logits_list = []
                for i, pid in enumerate(participant_ids):
                    # Unknown participants: use fixed head
                    participant_head = self.random_effects.get_slopes(
                        pid, fixed_head=self.classifier_head, create_if_missing=False
                    )
                    logits_i = participant_head(embeddings[i : i + 1])
                    logits_list.append(logits_i)
                logits = torch.cat(logits_list, dim=0)

            proba = torch.softmax(logits, dim=1).cpu().numpy()
            pred_classes = torch.argmax(logits, dim=1).cpu().numpy()

        predictions = []
        for i, item in enumerate(items):
            pred_label = self.category_names[pred_classes[i]]
            prob_dict = {
                cat: float(proba[i, idx]) for idx, cat in enumerate(self.category_names)
            }
            predictions.append(
                ModelPrediction(
                    item_id=str(item.id),
                    probabilities=prob_dict,
                    predicted_class=pred_label,
                    confidence=float(proba[i, pred_classes[i]]),
                )
            )

        return predictions

    def _do_predict_proba(
        self, items: list[Item], participant_ids: list[str]
    ) -> np.ndarray:
        """Perform categorical model probability prediction.

        Parameters
        ----------
        items : list[Item]
            Items to predict.
        participant_ids : list[str]
            Normalized participant IDs.

        Returns
        -------
        np.ndarray
            Probability array of shape (n_items, n_classes).
        """
        self.encoder.eval()
        self.classifier_head.eval()

        with torch.no_grad():
            embeddings = self._prepare_inputs(items)

            # Forward pass depends on mixed effects mode
            if self.config.mixed_effects.mode == "fixed":
                logits = self.classifier_head(embeddings)

            elif self.config.mixed_effects.mode == "random_intercepts":
                logits = self.classifier_head(embeddings)
                for i, pid in enumerate(participant_ids):
                    bias = self.random_effects.get_intercepts(
                        pid,
                        n_classes=self.num_classes,
                        param_name="mu",
                        create_if_missing=False,
                    )
                    logits[i] = logits[i] + bias

            elif self.config.mixed_effects.mode == "random_slopes":
                logits_list = []
                for i, pid in enumerate(participant_ids):
                    participant_head = self.random_effects.get_slopes(
                        pid, fixed_head=self.classifier_head, create_if_missing=False
                    )
                    logits_i = participant_head(embeddings[i : i + 1])
                    logits_list.append(logits_i)
                logits = torch.cat(logits_list, dim=0)

            proba = torch.softmax(logits, dim=1).cpu().numpy()

        return proba
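
    # Note: columns of the returned array follow category_names order (the
    # sorted training labels), matching the label-to-index mapping built in
    # _prepare_training_data.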

    def _get_save_state(self) -> dict[str, object]:
        """Get model-specific state to save.

        Returns
        -------
        dict[str, object]
            State dictionary.
        """
        return {
            "num_classes": self.num_classes,
            "category_names": self.category_names,
        }

    def _save_model_components(self, save_path: Path) -> None:
        """Save model-specific components.

        Parameters
        ----------
        save_path : Path
            Directory to save to.
        """
        self.encoder.save_pretrained(save_path / "encoder")
        self.tokenizer.save_pretrained(save_path / "encoder")

        torch.save(
            self.classifier_head.state_dict(),
            save_path / "classifier_head.pt",
        )

    def _restore_training_state(self, config_dict: dict[str, object]) -> None:
        """Restore model-specific training state.

        Parameters
        ----------
        config_dict : dict[str, object]
            Configuration dictionary with training state.
        """
        self.num_classes = config_dict.pop("num_classes")
        self.category_names = config_dict.pop("category_names")

    def _load_model_components(self, load_path: Path) -> None:
        """Load model-specific components.

        Parameters
        ----------
        load_path : Path
            Directory to load from.
        """
        # Load config.json to reconstruct config
        with open(load_path / "config.json") as f:
            import json  # noqa: PLC0415

            config_dict = json.load(f)

        # Reconstruct MixedEffectsConfig if needed
        if "mixed_effects" in config_dict and isinstance(
            config_dict["mixed_effects"], dict
        ):
            from bead.active_learning.config import MixedEffectsConfig  # noqa: PLC0415

            config_dict["mixed_effects"] = MixedEffectsConfig(
                **config_dict["mixed_effects"]
            )

        self.config = CategoricalModelConfig(**config_dict)

        self.encoder = AutoModel.from_pretrained(load_path / "encoder")
        self.tokenizer = AutoTokenizer.from_pretrained(load_path / "encoder")

        self._initialize_classifier(self.num_classes)
        self.classifier_head.load_state_dict(
            torch.load(
                load_path / "classifier_head.pt", map_location=self.config.device
            )
        )
        self.classifier_head.to(self.config.device)

    def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
        """Get fixed head for random effects loading.

        Returns
        -------
        nn.Module | None
            Fixed head module.
        """
        return self.classifier_head

    def _get_n_classes_for_random_effects(self) -> int:
        """Get number of classes for random effects initialization.

        Returns
        -------
        int
            Number of classes.
        """
        return self.num_classes
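
For orientation, a minimal end-to-end sketch assembled from the class docstring and the signatures above. The mode values come from the dispatch in _do_training; whether CategoricalModelConfig accepts a mixed_effects keyword is an assumption, inferred from how _load_model_components rebuilds it, and the two-annotator participant IDs are invented for illustration.

    from uuid import uuid4

    from bead.active_learning.config import MixedEffectsConfig
    from bead.active_learning.models.categorical import CategoricalModel
    from bead.config.active_learning import CategoricalModelConfig
    from bead.items.item import Item

    # Ten NLI-style items, two rendered elements each (as in the docstring).
    items = [
        Item(
            item_template_id=uuid4(),
            rendered_elements={"premise": "sent A", "hypothesis": "sent B"},
        )
        for _ in range(10)
    ]
    labels = ["entailment"] * 5 + ["contradiction"] * 5
    participants = [f"p{i % 2}" for i in range(10)]  # two hypothetical annotators

    # Assumed constructor shape; mode is one of "fixed", "random_intercepts",
    # or "random_slopes".
    config = CategoricalModelConfig(
        num_epochs=1,
        batch_size=2,
        device="cpu",
        mixed_effects=MixedEffectsConfig(mode="random_intercepts"),
    )

    model = CategoricalModel(config=config)
    metrics = model.train(items, labels, participant_ids=participants)
    predictions = model.predict(items[:3], participant_ids=participants[:3])
    for p in predictions:
        print(p.predicted_class, p.confidence)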