bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/models/multi_select.py
@@ -0,0 +1,795 @@
"""Multi-select model for selecting multiple options.

Expected architecture: Multi-label classification with sigmoid output per option.
Each option can be independently selected or not selected.
"""

from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

from bead.active_learning.config import MixedEffectsConfig, VarianceComponents
from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
from bead.active_learning.models.random_effects import RandomEffectsManager
from bead.config.active_learning import MultiSelectModelConfig
from bead.items.item import Item
from bead.items.item_template import ItemTemplate, TaskType

__all__ = ["MultiSelectModel"]


class MultiSelectModel(ActiveLearningModel):
    """Model for multi_select tasks with N selectable options.

    Uses multi-label classification where each option can be independently
    selected or not selected. Applies sigmoid activation to each option's
    logit and uses BCEWithLogitsLoss for training.

    Parameters
    ----------
    config : MultiSelectModelConfig
        Configuration object containing all model parameters.

    Attributes
    ----------
    config : MultiSelectModelConfig
        Model configuration.
    tokenizer : AutoTokenizer
        Transformer tokenizer.
    encoder : AutoModel
        Transformer encoder model.
    classifier_head : nn.Sequential
        Classification head (fixed effects head) - outputs N logits.
    num_options : int | None
        Number of selectable options (inferred from training data).
    option_names : list[str] | None
        Option names (e.g., ["option_a", "option_b", "option_c"]).
    random_effects : RandomEffectsManager
        Manager for participant-level random effects.
    variance_history : list[VarianceComponents]
        Variance component estimates over training (for diagnostics).
    _is_fitted : bool
        Whether model has been trained.

    Examples
    --------
    >>> from uuid import uuid4
    >>> from bead.items.item import Item
    >>> from bead.config.active_learning import MultiSelectModelConfig
    >>> items = [
    ...     Item(
    ...         item_template_id=uuid4(),
    ...         rendered_elements={
    ...             "option_a": "First option",
    ...             "option_b": "Second option",
    ...             "option_c": "Third option"
    ...         }
    ...     )
    ...     for _ in range(10)
    ... ]
    >>> # Labels as lists of selected options
    >>> labels_list = [["option_a", "option_b"], ["option_c"], ["option_a"]]
    >>> labels = labels_list * 3 + [["option_b"]]
    >>> config = MultiSelectModelConfig(  # doctest: +SKIP
    ...     num_epochs=1, batch_size=2, device="cpu"
    ... )
    >>> model = MultiSelectModel(config=config)  # doctest: +SKIP
    >>> # Convert labels to serialized format for train()
    >>> label_strs = [json.dumps(sorted(lbls)) for lbls in labels]  # doctest: +SKIP
    >>> metrics = model.train(items, label_strs, participant_ids=None)  # doctest: +SKIP

    Notes
    -----
    This model uses BCEWithLogitsLoss (not CrossEntropyLoss) and applies
    sigmoid activation to get independent probabilities for each option.
    Random intercepts are bias vectors (one per option) that shift logits
    independently for each participant.
    """

    def __init__(
        self,
        config: MultiSelectModelConfig | None = None,
    ) -> None:
        """Initialize multi-select model.

        Parameters
        ----------
        config : MultiSelectModelConfig | None
            Configuration object. If None, uses default configuration.
        """
        self.config = config or MultiSelectModelConfig()

        # Validate mixed_effects configuration
        super().__init__(self.config)

        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self.encoder = AutoModel.from_pretrained(self.config.model_name)

        self.num_options: int | None = None
        self.option_names: list[str] | None = None
        self.classifier_head: nn.Sequential | None = None
        self._is_fitted = False

        # Initialize random effects manager
        self.random_effects: RandomEffectsManager | None = None
        self.variance_history: list[VarianceComponents] = []

        self.encoder.to(self.config.device)

    @property
    def supported_task_types(self) -> list[TaskType]:
        """Get supported task types.

        Returns
        -------
        list[TaskType]
            List containing "multi_select".
        """
        return ["multi_select"]

    def validate_item_compatibility(
        self, item: Item, item_template: ItemTemplate
    ) -> None:
        """Validate item is compatible with multi-select model.

        Parameters
        ----------
        item : Item
            Item to validate.
        item_template : ItemTemplate
            Template the item was constructed from.

        Raises
        ------
        ValueError
            If task_type is not "multi_select".
        """
        if item_template.task_type != "multi_select":
            raise ValueError(
                f"Expected task_type 'multi_select', got '{item_template.task_type}'"
            )

    def _initialize_classifier(self, num_options: int) -> None:
        """Initialize classification head for given number of options.

        Parameters
        ----------
        num_options : int
            Number of selectable options (output units).
        """
        hidden_size = self.encoder.config.hidden_size

        if self.config.encoder_mode == "dual_encoder":
            input_size = hidden_size * num_options
        else:
            input_size = hidden_size

        self.classifier_head = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_options),  # N independent outputs
        )
        self.classifier_head.to(self.config.device)
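
For the head sizing above, a small standalone sketch (hidden size and option count are assumed values, not package defaults) shows how dual_encoder widens the input layer while single_encoder does not:

```python
import torch.nn as nn

hidden_size, num_options = 768, 3  # e.g. a BERT-base encoder with three options
for mode in ("single_encoder", "dual_encoder"):
    input_size = hidden_size * num_options if mode == "dual_encoder" else hidden_size
    head = nn.Sequential(
        nn.Linear(input_size, 256),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(256, num_options),  # one logit per option
    )
    print(mode, head[0].in_features)  # 768 for single_encoder, 2304 for dual_encoder
```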

    def _encode_single(self, texts: list[str]) -> torch.Tensor:
        """Encode texts using single encoder strategy.

        Concatenates all option texts with [SEP] tokens and encodes once.

        Parameters
        ----------
        texts : list[str]
            List of concatenated option texts for each item.

        Returns
        -------
        torch.Tensor
            Encoded representations of shape (batch_size, hidden_size).
        """
        encodings = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.config.max_length,
            return_tensors="pt",
        )
        encodings = {k: v.to(self.config.device) for k, v in encodings.items()}

        outputs = self.encoder(**encodings)
        return outputs.last_hidden_state[:, 0, :]
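
The `last_hidden_state[:, 0, :]` slice is CLS pooling: the first token's embedding stands in for the whole sequence. A shape-only sketch with a stand-in tensor (dimensions assumed):

```python
import torch

last_hidden_state = torch.randn(4, 12, 768)  # (batch, seq_len, hidden) from an encoder
pooled = last_hidden_state[:, 0, :]          # keep the [CLS] position for each item
print(pooled.shape)                          # torch.Size([4, 768])
```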

    def _encode_dual(self, options_per_item: list[list[str]]) -> torch.Tensor:
        """Encode texts using dual encoder strategy.

        Encodes each option separately and concatenates embeddings.

        Parameters
        ----------
        options_per_item : list[list[str]]
            List of option lists. Each inner list contains option texts for one item.

        Returns
        -------
        torch.Tensor
            Concatenated encodings of shape (batch_size, hidden_size * num_options).
        """
        all_embeddings = []

        for options in options_per_item:
            option_embeddings = []
            for option_text in options:
                encodings = self.tokenizer(
                    [option_text],
                    padding=True,
                    truncation=True,
                    max_length=self.config.max_length,
                    return_tensors="pt",
                )
                encodings = {k: v.to(self.config.device) for k, v in encodings.items()}

                outputs = self.encoder(**encodings)
                cls_embedding = outputs.last_hidden_state[0, 0, :]
                option_embeddings.append(cls_embedding)

            concatenated = torch.cat(option_embeddings, dim=0)
            all_embeddings.append(concatenated)

        return torch.stack(all_embeddings)

    def _prepare_inputs(self, items: list[Item]) -> torch.Tensor:
        """Prepare inputs for encoding based on encoder mode.

        For multi-select tasks, uses all options from rendered_elements.

        Parameters
        ----------
        items : list[Item]
            Items to encode.

        Returns
        -------
        torch.Tensor
            Encoded representations.
        """
        if self.option_names is None:
            raise ValueError("Model not initialized. Call train() first.")

        if self.config.encoder_mode == "single_encoder":
            texts = []
            for item in items:
                option_texts = [
                    item.rendered_elements.get(opt, "") for opt in self.option_names
                ]
                concatenated = " [SEP] ".join(option_texts)
                texts.append(concatenated)
            return self._encode_single(texts)
        else:
            options_per_item = []
            for item in items:
                option_texts = [
                    item.rendered_elements.get(opt, "") for opt in self.option_names
                ]
                options_per_item.append(option_texts)
            return self._encode_dual(options_per_item)

    def _parse_multi_select_labels(self, label_str: str) -> list[str]:
        """Parse multi-select label from JSON string.

        Parameters
        ----------
        label_str : str
            JSON-serialized list of selected options.

        Returns
        -------
        list[str]
            List of selected option names.
        """
        try:
            selected = json.loads(label_str)
            if not isinstance(selected, list):
                raise ValueError(
                    f"Label must be JSON list of option names, got {type(selected)}"
                )
            return selected
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Label must be valid JSON list of selected options. "
                f"Got: {label_str!r}. Error: {e}"
            ) from e
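
A hypothetical round-trip for the label format this parser enforces (option names invented for illustration): labels must be JSON arrays of option names, and anything else is rejected:

```python
import json

label_str = json.dumps(["option_a", "option_c"])  # well-formed label
print(json.loads(label_str))                      # ['option_a', 'option_c']

try:
    json.loads("option_a")                        # a bare string is not valid JSON
except json.JSONDecodeError as e:
    print("rejected:", e)                         # re-raised above as ValueError
```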

    def _prepare_training_data(
        self,
        items: list[Item],
        labels: list[str],
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels: list[str] | None,
    ) -> tuple[
        list[Item], torch.Tensor, list[str], list[Item] | None, torch.Tensor | None
    ]:
        """Prepare training data for multi-select model.

        Parameters
        ----------
        items : list[Item]
            Training items.
        labels : list[str]
            Training labels (JSON strings of selected options).
        participant_ids : list[str]
            Normalized participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels : list[str] | None
            Validation labels.

        Returns
        -------
        tuple
            Prepared items, labels, participant_ids, val items, val labels.
        """
        if not items:
            raise ValueError("Cannot train with empty items list")

        # Infer option names from first item
        self.option_names = sorted(items[0].rendered_elements.keys())
        self.num_options = len(self.option_names)
        option_to_idx = {opt: idx for idx, opt in enumerate(self.option_names)}

        # Parse labels and convert to binary matrix
        y = torch.zeros(
            (len(items), self.num_options), dtype=torch.float, device=self.config.device
        )
        for i, label_str in enumerate(labels):
            selected_options = self._parse_multi_select_labels(label_str)
            for opt in selected_options:
                if opt not in option_to_idx:
                    raise ValueError(
                        f"Invalid option {opt!r} in label. "
                        f"Valid options: {self.option_names}"
                    )
                y[i, option_to_idx[opt]] = 1.0

        self._initialize_classifier(self.num_options)

        # Convert validation labels if provided
        val_y = None
        if validation_items is not None and validation_labels is not None:
            if len(validation_items) != len(validation_labels):
                raise ValueError(
                    f"Number of validation items ({len(validation_items)}) "
                    f"must match number of validation labels ({len(validation_labels)})"
                )
            val_y = torch.zeros(
                (len(validation_items), self.num_options),
                dtype=torch.float,
                device=self.config.device,
            )
            for i, label_str in enumerate(validation_labels):
                selected_options = self._parse_multi_select_labels(label_str)
                for opt in selected_options:
                    if opt not in option_to_idx:
                        raise ValueError(
                            f"Invalid option {opt!r} in validation label. "
                            f"Valid options: {self.option_names}"
                        )
                    val_y[i, option_to_idx[opt]] = 1.0

        return items, y, participant_ids, validation_items, val_y
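
The label-to-matrix conversion above reduces to the following sketch (option names and labels assumed): each parsed label marks ones in its row of a binary (n_items, n_options) tensor:

```python
import json

import torch

option_names = ["option_a", "option_b", "option_c"]
option_to_idx = {opt: idx for idx, opt in enumerate(option_names)}
labels = [json.dumps(["option_a", "option_b"]), json.dumps(["option_c"])]

y = torch.zeros((len(labels), len(option_names)))
for i, label_str in enumerate(labels):
    for opt in json.loads(label_str):
        y[i, option_to_idx[opt]] = 1.0
print(y)  # tensor([[1., 1., 0.],
          #         [0., 0., 1.]])
```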

    def _initialize_random_effects(self, n_classes: int) -> None:
        """Initialize random effects manager.

        Parameters
        ----------
        n_classes : int
            Number of classes (num_options for multi-select).
        """
        self.random_effects = RandomEffectsManager(
            self.config.mixed_effects, n_classes=n_classes
        )

    def _do_training(
        self,
        items: list[Item],
        labels_numeric: torch.Tensor,
        participant_ids: list[str],
        validation_items: list[Item] | None,
        validation_labels_numeric: torch.Tensor | None,
    ) -> dict[str, float]:
        """Perform multi-select model training.

        Parameters
        ----------
        items : list[Item]
            Training items.
        labels_numeric : torch.Tensor
            Binary label tensor of shape (n_items, n_options).
        participant_ids : list[str]
            Participant IDs.
        validation_items : list[Item] | None
            Validation items.
        validation_labels_numeric : torch.Tensor | None
            Validation label tensor.

        Returns
        -------
        dict[str, float]
            Training metrics.
        """
        y = labels_numeric

        # Build optimizer parameters based on mode
        params_to_optimize = list(self.encoder.parameters()) + list(
            self.classifier_head.parameters()
        )

        # Add random effects parameters
        if self.config.mixed_effects.mode == "random_intercepts":
            for param_dict in self.random_effects.intercepts.values():
                params_to_optimize.extend(param_dict.values())
        elif self.config.mixed_effects.mode == "random_slopes":
            for head in self.random_effects.slopes.values():
                params_to_optimize.extend(head.parameters())

        optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
        # BCE with Logits Loss for multi-label classification
        criterion = nn.BCEWithLogitsLoss()

        self.encoder.train()
        self.classifier_head.train()

        epoch_acc = 0.0
        epoch_loss = 0.0

        for _epoch in range(self.config.num_epochs):
            n_batches = (
                len(items) + self.config.batch_size - 1
            ) // self.config.batch_size
            epoch_loss = 0.0
            epoch_correct_predictions = 0
            epoch_total_predictions = 0

            for i in range(n_batches):
                start_idx = i * self.config.batch_size
                end_idx = min(start_idx + self.config.batch_size, len(items))

                batch_items = items[start_idx:end_idx]
                batch_labels = y[start_idx:end_idx]
                batch_participant_ids = participant_ids[start_idx:end_idx]

                embeddings = self._prepare_inputs(batch_items)

                # Forward pass depends on mixed effects mode
                if self.config.mixed_effects.mode == "fixed":
                    # Standard forward pass
                    logits = self.classifier_head(embeddings)

                elif self.config.mixed_effects.mode == "random_intercepts":
                    # Fixed head + per-participant bias (independent per option)
                    logits = self.classifier_head(embeddings)
                    for j, pid in enumerate(batch_participant_ids):
                        bias = self.random_effects.get_intercepts(
                            pid,
                            n_classes=self.num_options,
                            param_name="mu",
                            create_if_missing=True,
                        )
                        logits[j] = logits[j] + bias

                elif self.config.mixed_effects.mode == "random_slopes":
                    # Per-participant head
                    logits_list = []
                    for j, pid in enumerate(batch_participant_ids):
                        participant_head = self.random_effects.get_slopes(
                            pid,
                            fixed_head=self.classifier_head,
                            create_if_missing=True,
                        )
                        logits_j = participant_head(embeddings[j : j + 1])
                        logits_list.append(logits_j)
                    logits = torch.cat(logits_list, dim=0)

                # Data loss + prior regularization
                loss_bce = criterion(logits, batch_labels)
                loss_prior = self.random_effects.compute_prior_loss()
                loss = loss_bce + loss_prior

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                # Predictions: threshold at 0.5 on sigmoid(logits)
                predictions = (torch.sigmoid(logits) > 0.5).float()
                # Hamming accuracy: fraction of correct predictions (per option)
                batch_correct = (predictions == batch_labels).sum().item()
                batch_total = batch_labels.numel()
                epoch_correct_predictions += batch_correct
                epoch_total_predictions += batch_total

            # Hamming accuracy: average over all (item, option) pairs
            epoch_acc = epoch_correct_predictions / epoch_total_predictions
            epoch_loss = epoch_loss / n_batches

        metrics: dict[str, float] = {
            "train_accuracy": epoch_acc,
            "train_loss": epoch_loss,
        }

        # Add validation accuracy if validation data provided
        if validation_items is not None and validation_labels_numeric is not None:
            # Validation with placeholder participant_ids for mixed effects
            if self.config.mixed_effects.mode == "fixed":
                val_participant_ids = ["_fixed_"] * len(validation_items)
            else:
                val_participant_ids = ["_validation_"] * len(validation_items)
            val_predictions = self._do_predict(validation_items, val_participant_ids)

            # Parse validation labels
            val_labels_parsed = []
            for i in range(validation_labels_numeric.shape[0]):
                selected = [
                    self.option_names[j]
                    for j in range(self.num_options)
                    if validation_labels_numeric[i, j] > 0.5
                ]
                val_labels_parsed.append(set(selected))

            # Compute Hamming accuracy
            val_correct = 0
            val_total = 0
            for pred, true_set in zip(val_predictions, val_labels_parsed, strict=True):
                # pred.predicted_class is JSON string of selected options
                pred_set = set(json.loads(pred.predicted_class))
                for opt in self.option_names:
                    if (opt in pred_set) == (opt in true_set):
                        val_correct += 1
                    val_total += 1

            val_acc = val_correct / val_total
            metrics["val_accuracy"] = val_acc

        return metrics
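
The Hamming accuracy reported in these metrics counts per-cell agreement, not exact-set matches, so an item with one wrong option still earns partial credit. A worked example with assumed predictions:

```python
import torch

predictions = torch.tensor([[1.0, 0.0, 1.0],
                            [0.0, 1.0, 0.0]])
labels = torch.tensor([[1.0, 1.0, 1.0],
                       [0.0, 1.0, 0.0]])
correct = (predictions == labels).sum().item()  # 5 of 6 (item, option) cells agree
print(correct / labels.numel())                 # 0.8333... (exact-match would be 0.5)
```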

    def _do_predict(
        self, items: list[Item], participant_ids: list[str]
    ) -> list[ModelPrediction]:
        """Perform multi-select model prediction.

        Parameters
        ----------
        items : list[Item]
            Items to predict.
        participant_ids : list[str]
            Normalized participant IDs.

        Returns
        -------
        list[ModelPrediction]
            Predictions.
        """
        self.encoder.eval()
        self.classifier_head.eval()

        with torch.no_grad():
            embeddings = self._prepare_inputs(items)

            # Forward pass depends on mixed effects mode
            if self.config.mixed_effects.mode == "fixed":
                logits = self.classifier_head(embeddings)

            elif self.config.mixed_effects.mode == "random_intercepts":
                logits = self.classifier_head(embeddings)
                for i, pid in enumerate(participant_ids):
                    # Unknown participants: use prior mean (zero bias)
                    bias = self.random_effects.get_intercepts(
                        pid,
                        n_classes=self.num_options,
                        param_name="mu",
                        create_if_missing=False,
                    )
                    logits[i] = logits[i] + bias

            elif self.config.mixed_effects.mode == "random_slopes":
                logits_list = []
                for i, pid in enumerate(participant_ids):
                    # Unknown participants: use fixed head
                    participant_head = self.random_effects.get_slopes(
                        pid, fixed_head=self.classifier_head, create_if_missing=False
                    )
                    logits_i = participant_head(embeddings[i : i + 1])
                    logits_list.append(logits_i)
                logits = torch.cat(logits_list, dim=0)

            # Compute probabilities using sigmoid
            proba = torch.sigmoid(logits).cpu().numpy()  # (n_items, n_options)
            pred_binary = proba > 0.5  # Threshold at 0.5

            predictions = []
            for i, item in enumerate(items):
                # Determine selected options
                selected_options = [
                    self.option_names[j]
                    for j in range(self.num_options)
                    if pred_binary[i, j]
                ]

                # Build probability dict: {option: probability}
                prob_dict = {
                    opt: float(proba[i, idx]) for idx, opt in enumerate(self.option_names)
                }

                # Confidence: average probability of selected options (or 0.5 if none)
                if selected_options:
                    option_probs = [
                        proba[i, self.option_names.index(opt)] for opt in selected_options
                    ]
                    confidence = float(np.mean(option_probs))
                else:
                    confidence = 0.5  # Neutral confidence when nothing selected

                predictions.append(
                    ModelPrediction(
                        item_id=str(item.id),
                        probabilities=prob_dict,
                        predicted_class=json.dumps(sorted(selected_options)),
                        confidence=confidence,
                    )
                )

        return predictions
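
The decision rule at the end of `_do_predict` can be traced by hand with assumed probabilities: threshold at 0.5, serialize the sorted selection, and average the selected options' probabilities for confidence:

```python
import json

import numpy as np

option_names = ["option_a", "option_b", "option_c"]
proba = np.array([0.9, 0.2, 0.7])  # assumed per-option probabilities for one item
selected = [opt for opt, p in zip(option_names, proba) if p > 0.5]
confidence = float(np.mean([p for p in proba if p > 0.5])) if selected else 0.5
print(json.dumps(sorted(selected)), round(confidence, 2))  # ["option_a", "option_c"] 0.8
```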

    def _do_predict_proba(
        self, items: list[Item], participant_ids: list[str]
    ) -> np.ndarray:
        """Perform multi-select model probability prediction.

        Parameters
        ----------
        items : list[Item]
            Items to predict.
        participant_ids : list[str]
            Normalized participant IDs.

        Returns
        -------
        np.ndarray
            Probability array of shape (n_items, n_options).
        """
        self.encoder.eval()
        self.classifier_head.eval()

        with torch.no_grad():
            embeddings = self._prepare_inputs(items)

            # Forward pass depends on mixed effects mode
            if self.config.mixed_effects.mode == "fixed":
                logits = self.classifier_head(embeddings)

            elif self.config.mixed_effects.mode == "random_intercepts":
                logits = self.classifier_head(embeddings)
                for i, pid in enumerate(participant_ids):
                    bias = self.random_effects.get_intercepts(
                        pid,
                        n_classes=self.num_options,
                        param_name="mu",
                        create_if_missing=False,
                    )
                    logits[i] = logits[i] + bias

            elif self.config.mixed_effects.mode == "random_slopes":
                logits_list = []
                for i, pid in enumerate(participant_ids):
                    participant_head = self.random_effects.get_slopes(
                        pid, fixed_head=self.classifier_head, create_if_missing=False
                    )
                    logits_i = participant_head(embeddings[i : i + 1])
                    logits_list.append(logits_i)
                logits = torch.cat(logits_list, dim=0)

            # Compute probabilities using sigmoid
            proba = torch.sigmoid(logits).cpu().numpy()

        return proba

    def _get_save_state(self) -> dict[str, object]:
        """Get model-specific state to save.

        Returns
        -------
        dict[str, object]
            State dictionary.
        """
        return {
            "num_options": self.num_options,
            "option_names": self.option_names,
        }

    def _save_model_components(self, save_path: Path) -> None:
        """Save model-specific components.

        Parameters
        ----------
        save_path : Path
            Directory to save to.
        """
        self.encoder.save_pretrained(save_path / "encoder")
        self.tokenizer.save_pretrained(save_path / "encoder")

        torch.save(
            self.classifier_head.state_dict(),
            save_path / "classifier_head.pt",
        )

    def _restore_training_state(self, config_dict: dict[str, object]) -> None:
        """Restore model-specific training state.

        Parameters
        ----------
        config_dict : dict[str, object]
            Configuration dictionary with training state.
        """
        self.num_options = config_dict.pop("num_options")
        self.option_names = config_dict.pop("option_names")

    def _load_model_components(self, load_path: Path) -> None:
        """Load model-specific components.

        Parameters
        ----------
        load_path : Path
            Directory to load from.
        """
        # Load config.json to reconstruct config
        with open(load_path / "config.json") as f:
            config_dict = json.load(f)

        # Reconstruct MixedEffectsConfig if needed
        if "mixed_effects" in config_dict and isinstance(
            config_dict["mixed_effects"], dict
        ):
            config_dict["mixed_effects"] = MixedEffectsConfig(
                **config_dict["mixed_effects"]
            )

        self.config = MultiSelectModelConfig(**config_dict)

        self.encoder = AutoModel.from_pretrained(load_path / "encoder")
        self.tokenizer = AutoTokenizer.from_pretrained(load_path / "encoder")

        self._initialize_classifier(self.num_options)
        self.classifier_head.load_state_dict(
            torch.load(
                load_path / "classifier_head.pt", map_location=self.config.device
            )
        )
        self.classifier_head.to(self.config.device)

    def _get_random_effects_fixed_head(self) -> nn.Sequential | None:
        """Get fixed head for random effects loading.

        Returns
        -------
        nn.Sequential | None
            Fixed head module.
        """
        return self.classifier_head

    def _get_n_classes_for_random_effects(self) -> int:
        """Get number of classes for random effects initialization.

        Returns
        -------
        int
            Number of options.
        """
        return self.num_options