bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,773 @@
|
|
|
1
|
+
"""Free text model for open-ended text generation with GLMM support.
|
|
2
|
+
|
|
3
|
+
Implements seq2seq generation with participant-level random effects using:
|
|
4
|
+
- Random intercepts: Bias on decoder output logits (token probability shifts)
|
|
5
|
+
- Random slopes: LoRA adapters on decoder attention layers
|
|
6
|
+
|
|
7
|
+
Architecture: T5-base or BART-base encoder-decoder model
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import torch
|
|
17
|
+
import torch.nn.functional
|
|
18
|
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
19
|
+
|
|
20
|
+
from bead.active_learning.config import VarianceComponents
|
|
21
|
+
from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
|
|
22
|
+
from bead.active_learning.models.peft_adapter import create_participant_lora_adapter
|
|
23
|
+
from bead.active_learning.models.random_effects import RandomEffectsManager
|
|
24
|
+
from bead.config.active_learning import FreeTextModelConfig
|
|
25
|
+
from bead.items.item import Item
|
|
26
|
+
from bead.items.item_template import ItemTemplate, TaskType
|
|
27
|
+
|
|
28
|
+
# Public API of this module: only the model class is exported.
__all__ = ["FreeTextModel"]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class FreeTextModel(ActiveLearningModel):
|
|
32
|
+
"""Model for free_text tasks with participant-level random effects.
|
|
33
|
+
|
|
34
|
+
Uses seq2seq architecture (T5 or BART) with three modes:
|
|
35
|
+
- Fixed effects: Standard encoder-decoder
|
|
36
|
+
- Random intercepts: Participant-specific bias on output logits
|
|
37
|
+
- Random slopes: Participant-specific LoRA adapters on decoder
|
|
38
|
+
|
|
39
|
+
Parameters
|
|
40
|
+
----------
|
|
41
|
+
config : FreeTextModelConfig
|
|
42
|
+
Configuration object containing all model parameters.
|
|
43
|
+
|
|
44
|
+
Attributes
|
|
45
|
+
----------
|
|
46
|
+
config : FreeTextModelConfig
|
|
47
|
+
Model configuration.
|
|
48
|
+
tokenizer : AutoTokenizer
|
|
49
|
+
Seq2seq tokenizer.
|
|
50
|
+
model : AutoModelForSeq2SeqLM
|
|
51
|
+
Base seq2seq model (T5 or BART).
|
|
52
|
+
encoder : nn.Module
|
|
53
|
+
Encoder module.
|
|
54
|
+
base_decoder : nn.Module
|
|
55
|
+
Base decoder module (shared across participants in fixed/random_intercepts).
|
|
56
|
+
lm_head : nn.Module
|
|
57
|
+
Language modeling head (projects decoder output to vocabulary).
|
|
58
|
+
random_effects : RandomEffectsManager
|
|
59
|
+
Manager for participant-level random effects.
|
|
60
|
+
variance_history : list[VarianceComponents]
|
|
61
|
+
Variance component estimates over training.
|
|
62
|
+
_is_fitted : bool
|
|
63
|
+
Whether model has been trained.
|
|
64
|
+
|
|
65
|
+
Examples
|
|
66
|
+
--------
|
|
67
|
+
>>> from uuid import uuid4
|
|
68
|
+
>>> from bead.items.item import Item
|
|
69
|
+
>>> from bead.config.active_learning import FreeTextModelConfig
|
|
70
|
+
>>> items = [
|
|
71
|
+
... Item(
|
|
72
|
+
... item_template_id=uuid4(),
|
|
73
|
+
... rendered_elements={"prompt": "Summarize: The cat sat."}
|
|
74
|
+
... )
|
|
75
|
+
... for _ in range(10)
|
|
76
|
+
... ]
|
|
77
|
+
>>> labels = ["Cat sits."] * 10
|
|
78
|
+
>>> config = FreeTextModelConfig( # doctest: +SKIP
|
|
79
|
+
... num_epochs=1, batch_size=2, device="cpu"
|
|
80
|
+
... )
|
|
81
|
+
>>> model = FreeTextModel(config=config) # doctest: +SKIP
|
|
82
|
+
>>> metrics = model.train(items, labels, participant_ids=None) # doctest: +SKIP
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
def __init__(
    self,
    config: FreeTextModelConfig | None = None,
) -> None:
    """Build the seq2seq backbone and empty random-effects state.

    Parameters
    ----------
    config : FreeTextModelConfig | None
        Model configuration. When omitted (or falsy), a default
        ``FreeTextModelConfig()`` is constructed.
    """
    self.config = config or FreeTextModelConfig()
    # The base-class initializer receives the resolved configuration.
    super().__init__(self.config)

    checkpoint = self.config.model_name
    self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

    # Keep direct handles on the sub-modules so the custom training and
    # decoding paths can insert participant effects between them.
    self.encoder = self.model.get_encoder()
    self.base_decoder = self.model.get_decoder()
    self.lm_head = self.model.lm_head

    self._is_fitted = False

    # No random-effects manager yet; it is expected to be attached elsewhere
    # before training (not visible in this method).
    self.random_effects: RandomEffectsManager | None = None
    self.variance_history: list[VarianceComponents] = []

    self.model.to(self.config.device)
|
|
117
|
+
|
|
118
|
+
@property
def supported_task_types(self) -> list[TaskType]:
    """Task types this model can be trained on.

    Returns
    -------
    list[TaskType]
        A freshly built single-element list, ``["free_text"]``.
    """
    task_types: list[TaskType] = ["free_text"]
    return task_types
|
|
128
|
+
|
|
129
|
+
def validate_item_compatibility(
    self, item: Item, item_template: ItemTemplate
) -> None:
    """Check that an item's template is a free-text task.

    Parameters
    ----------
    item : Item
        Item to validate (not inspected; only its template matters).
    item_template : ItemTemplate
        Template the item was constructed from.

    Raises
    ------
    ValueError
        If ``item_template.task_type`` is anything other than ``"free_text"``.
    """
    task_type = item_template.task_type
    if task_type == "free_text":
        return
    raise ValueError(f"Expected task_type 'free_text', got '{task_type}'")
|
|
150
|
+
|
|
151
|
+
def _prepare_inputs(self, items: list[Item]) -> str:
|
|
152
|
+
"""Prepare input texts from items.
|
|
153
|
+
|
|
154
|
+
For free text tasks, concatenates all rendered elements as prompt.
|
|
155
|
+
|
|
156
|
+
Parameters
|
|
157
|
+
----------
|
|
158
|
+
items : list[Item]
|
|
159
|
+
Items to encode.
|
|
160
|
+
|
|
161
|
+
Returns
|
|
162
|
+
-------
|
|
163
|
+
list[str]
|
|
164
|
+
Input texts.
|
|
165
|
+
"""
|
|
166
|
+
texts = []
|
|
167
|
+
for item in items:
|
|
168
|
+
# Concatenate all rendered elements as input
|
|
169
|
+
text = " ".join(item.rendered_elements.values())
|
|
170
|
+
texts.append(text)
|
|
171
|
+
return texts
|
|
172
|
+
|
|
173
|
+
def _prepare_training_data(
|
|
174
|
+
self,
|
|
175
|
+
items: list[Item],
|
|
176
|
+
labels: list[str],
|
|
177
|
+
participant_ids: list[str],
|
|
178
|
+
validation_items: list[Item] | None,
|
|
179
|
+
validation_labels: list[str] | None,
|
|
180
|
+
) -> tuple[
|
|
181
|
+
list[Item],
|
|
182
|
+
list[str],
|
|
183
|
+
list[str],
|
|
184
|
+
list[Item] | None,
|
|
185
|
+
list[str] | None,
|
|
186
|
+
]:
|
|
187
|
+
"""Prepare data for training, including validation.
|
|
188
|
+
|
|
189
|
+
Parameters
|
|
190
|
+
----------
|
|
191
|
+
items : list[Item]
|
|
192
|
+
Training items.
|
|
193
|
+
labels : list[str]
|
|
194
|
+
Training labels (target text strings).
|
|
195
|
+
participant_ids : list[str]
|
|
196
|
+
Participant identifiers.
|
|
197
|
+
validation_items : list[Item] | None
|
|
198
|
+
Optional validation items.
|
|
199
|
+
validation_labels : list[str] | None
|
|
200
|
+
Optional validation labels.
|
|
201
|
+
|
|
202
|
+
Returns
|
|
203
|
+
-------
|
|
204
|
+
tuple
|
|
205
|
+
Prepared training data: items, labels, participant_ids,
|
|
206
|
+
validation_items, validation_labels.
|
|
207
|
+
|
|
208
|
+
Raises
|
|
209
|
+
------
|
|
210
|
+
ValueError
|
|
211
|
+
If labels contain empty strings.
|
|
212
|
+
"""
|
|
213
|
+
if any(not label for label in labels):
|
|
214
|
+
raise ValueError(
|
|
215
|
+
"labels cannot contain empty strings. "
|
|
216
|
+
"Ensure all labels are non-empty text."
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
val_labels_list: list[str] | None = None
|
|
220
|
+
if validation_items is not None and validation_labels is not None:
|
|
221
|
+
if any(not label for label in validation_labels):
|
|
222
|
+
raise ValueError(
|
|
223
|
+
"validation_labels cannot contain empty strings. "
|
|
224
|
+
"Ensure all validation labels are non-empty text."
|
|
225
|
+
)
|
|
226
|
+
val_labels_list = validation_labels
|
|
227
|
+
|
|
228
|
+
return items, labels, participant_ids, validation_items, val_labels_list
|
|
229
|
+
|
|
230
|
+
def _do_training(
    self,
    items: list[Item],
    labels_numeric: list[str],
    participant_ids: list[str],
    validation_items: list[Item] | None,
    validation_labels_numeric: list[str] | None,
) -> dict[str, float]:
    """Run the custom seq2seq training loop and collect metrics.

    Items are processed in their given order (no shuffling), in mini-batches
    of ``config.batch_size``, for ``config.num_epochs`` epochs. The forward
    pass depends on ``config.mixed_effects.mode``:

    - ``"fixed"``: the wrapped model computes its own loss from ``labels``.
    - ``"random_intercepts"``: a per-participant bias vector over the
      vocabulary (``param_name="mu"``) is added to the logits at every
      decoder position before cross-entropy.
    - ``"random_slopes"``: each item is decoded by its participant's
      LoRA-adapted decoder, and the per-item losses are averaged.

    In every mode ``random_effects.compute_prior_loss()`` is added to the
    negative log-likelihood before back-propagation.

    Parameters
    ----------
    items : list[Item]
        Training items.
    labels_numeric : list[str]
        Target texts (strings, one per item).
    participant_ids : list[str]
        Participant identifier for each item.
    validation_items : list[Item] | None
        Optional validation items.
    validation_labels_numeric : list[str] | None
        Optional validation target texts.

    Returns
    -------
    dict[str, float]
        The returned metrics are:

        - ``train_loss``: mean batch loss (NLL plus prior) over the final
          epoch.
        - ``train_exact_match``: always present.
        - ``val_exact_match``: present when validation data is given.
        - ``participant_variance`` and ``n_participants``: present when
          variance components are estimated.

    Raises
    ------
    ValueError
        If ``items`` is empty or the mixed-effects mode is not recognised.
    """
    if not items:
        raise ValueError("Cannot train FreeTextModel on an empty list of items.")

    mode = self.config.mixed_effects.mode
    batch_size = self.config.batch_size
    input_texts = self._prepare_inputs(items)
    # Vocabulary size as seen by the output layer (may differ from the
    # tokenizer's length); used to size the random-intercept vectors.
    vocab_size = self.lm_head.out_features

    # Materialize every participant's random effects BEFORE building the
    # optimizer. Effects first created inside the batch loop would be missing
    # from the optimizer's parameter list and would never be updated.
    self._ensure_random_effects(participant_ids, vocab_size)
    optimizer = torch.optim.AdamW(
        self._collect_trainable_parameters(), lr=self.config.learning_rate
    )

    n_batches = (len(items) + batch_size - 1) // batch_size
    epoch_loss = 0.0  # stays 0.0 when num_epochs == 0
    self.model.train()
    for _epoch in range(self.config.num_epochs):
        epoch_loss = 0.0
        for start in range(0, len(items), batch_size):
            stop = start + batch_size  # slicing clamps the final batch
            loss_nll = self._batch_nll(
                input_texts[start:stop],
                labels_numeric[start:stop],
                participant_ids[start:stop],
                vocab_size,
            )
            # Prior regularization on the random effects.
            loss = loss_nll + self.random_effects.compute_prior_loss()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
        epoch_loss /= n_batches

    metrics: dict[str, float] = {"train_loss": epoch_loss}

    if self.config.mixed_effects.estimate_variance_components:
        var_comps = self.random_effects.estimate_variance_components()
        if var_comps:
            var_comp = var_comps.get("mu") or var_comps.get("slopes")
            if var_comp:
                if not hasattr(self, "variance_history"):
                    self.variance_history = []
                self.variance_history.append(var_comp)
                metrics["participant_variance"] = var_comp.variance
                metrics["n_participants"] = var_comp.n_groups

    train_predictions = self._do_predict(items, participant_ids)
    metrics["train_exact_match"] = self._compute_exact_match(
        [p.predicted_class for p in train_predictions], labels_numeric
    )

    if validation_items is not None and validation_labels_numeric is not None:
        # Validation uses one placeholder participant id for every item.
        placeholder = "_fixed_" if mode == "fixed" else "_validation_"
        val_predictions = self._do_predict(
            validation_items, [placeholder] * len(validation_items)
        )
        metrics["val_exact_match"] = self._compute_exact_match(
            [p.predicted_class for p in val_predictions], validation_labels_numeric
        )

    return metrics

def _ensure_random_effects(self, participant_ids: list[str], vocab_size: int) -> None:
    """Create (if missing) the random effects of every participant in the data."""
    mode = self.config.mixed_effects.mode
    # dict.fromkeys: unique ids in first-seen order (same order the batch
    # loop would have created them in).
    for pid in dict.fromkeys(participant_ids):
        if mode == "random_intercepts":
            self.random_effects.get_intercepts(
                pid,
                n_classes=vocab_size,
                param_name="mu",
                create_if_missing=True,
            )
        elif mode == "random_slopes":
            self._get_or_create_participant_decoder(pid)

def _collect_trainable_parameters(self) -> list[torch.Tensor]:
    """Gather backbone and random-effect parameters, each exactly once."""
    params = list(self.model.parameters())
    mode = self.config.mixed_effects.mode
    if mode == "random_intercepts":
        for param_dict in self.random_effects.intercepts.values():
            params.extend(param_dict.values())
    elif mode == "random_slopes":
        for adapter in self.random_effects.slopes.values():
            params.extend(adapter.get_lora_parameters())

    # If an adapter injects its LoRA layers into the shared decoder, the same
    # tensors would also be reachable via self.model; dedupe by identity so
    # the optimizer never receives a parameter twice.
    seen: set[int] = set()
    unique: list[torch.Tensor] = []
    for param in params:
        if id(param) not in seen:
            seen.add(id(param))
            unique.append(param)
    return unique

def _get_or_create_participant_decoder(self, participant_id: str):
    """Return the participant's LoRA-adapted decoder, creating it if missing."""
    # NOTE(review): the adapter passed as ``fixed_head`` is constructed on
    # every call, even when the participant already has one. If get_slopes
    # only uses it on creation, this could be made lazy.
    return self.random_effects.get_slopes(
        participant_id,
        fixed_head=create_participant_lora_adapter(
            self.base_decoder,
            rank=self.config.lora_rank,
            alpha=self.config.lora_alpha,
            dropout=self.config.lora_dropout,
            target_modules=self.config.lora_target_modules,
        ),
        create_if_missing=True,
    )

def _batch_nll(
    self,
    batch_input_texts: list[str],
    batch_labels: list[str],
    batch_participant_ids: list[str],
    vocab_size: int,
) -> torch.Tensor:
    """Mean token-level negative log-likelihood of one batch (current mode)."""
    device = self.config.device
    inputs = self.tokenizer(
        batch_input_texts,
        padding=True,
        truncation=True,
        max_length=self.config.max_input_length,
        return_tensors="pt",
    ).to(device)
    targets = self.tokenizer(
        text_target=batch_labels,
        padding=True,
        truncation=True,
        max_length=self.config.max_output_length,
        return_tensors="pt",
    ).to(device)

    # Clone so that masking padding with -100 (ignored by the loss) does not
    # also overwrite targets["input_ids"].
    target_ids = targets["input_ids"].clone()
    target_ids[target_ids == self.tokenizer.pad_token_id] = -100

    mode = self.config.mixed_effects.mode
    if mode == "fixed":
        # The wrapped model builds its decoder inputs from ``labels`` itself.
        return self.model(**inputs, labels=target_ids).loss

    # Teacher forcing: the decoder must see the targets shifted right by one
    # (decoder start token prepended, -100 mapped back to padding), so that
    # position t predicts target token t from tokens < t.
    decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
        labels=target_ids
    )
    # NOTE(review): decoder states are projected with lm_head alone. The
    # wrapped model's own forward may post-process (e.g. T5 rescales hidden
    # states when embeddings are tied; BART adds final_logits_bias), so these
    # logits can differ from "fixed" mode. Confirm this is intended.

    if mode == "random_intercepts":
        encoder_outputs = self.encoder(**inputs)
        decoder_outputs = self.base_decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=inputs["attention_mask"],
        )
        logits = self.lm_head(decoder_outputs.last_hidden_state)
        biases = torch.stack(
            [
                self.random_effects.get_intercepts(
                    pid,
                    n_classes=vocab_size,
                    param_name="mu",
                    create_if_missing=True,
                )
                for pid in batch_participant_ids
            ]
        )
        # One bias vector per row, broadcast over every sequence position.
        logits = logits + biases.unsqueeze(1)
        return torch.nn.functional.cross_entropy(
            logits.reshape(-1, vocab_size),
            target_ids.reshape(-1),
            ignore_index=-100,
        )

    if mode == "random_slopes":
        # Each participant has its own decoder, so items are run one by one.
        losses = []
        for j, pid in enumerate(batch_participant_ids):
            participant_decoder = self._get_or_create_participant_decoder(pid)
            row = slice(j, j + 1)
            item_inputs = {k: v[row] for k, v in inputs.items()}
            encoder_outputs_j = self.encoder(**item_inputs)
            decoder_outputs_j = participant_decoder(
                input_ids=decoder_input_ids[row],
                encoder_hidden_states=encoder_outputs_j.last_hidden_state,
                encoder_attention_mask=item_inputs["attention_mask"],
            )
            logits_j = self.lm_head(decoder_outputs_j.last_hidden_state)
            losses.append(
                torch.nn.functional.cross_entropy(
                    logits_j.reshape(-1, vocab_size),
                    target_ids[row].reshape(-1),
                    ignore_index=-100,
                )
            )
        return torch.stack(losses).mean()

    raise ValueError(f"Unsupported mixed_effects mode: {mode!r}")
|
|
448
|
+
|
|
449
|
+
def _do_predict(
    self, items: list[Item], participant_ids: list[str]
) -> list[ModelPrediction]:
    """Generate one text per item, applying participant effects if configured.

    Decoding depends on ``config.mixed_effects.mode``:

    - ``"fixed"``: ``model.generate`` with the configured
      ``max_output_length``, ``num_beams``, ``temperature`` and ``top_p``.
      It runs in mini-batches of ``config.batch_size``.
    - ``"random_intercepts"``: greedy decoding that adds the participant's
      vocabulary bias (``param_name="mu"``) to every step's logits.
    - ``"random_slopes"``: greedy decoding with the participant's
      LoRA-adapted decoder.

    The random-effects modes look effects up with ``create_if_missing=False``.
    They ignore ``num_beams``, ``temperature`` and ``top_p``.

    Parameters
    ----------
    items : list[Item]
        Items to predict.
    participant_ids : list[str]
        Participant identifier for each item. It is only consulted in the
        random-effects modes, where it must have the same length as ``items``.

    Returns
    -------
    list[ModelPrediction]
        One prediction per item, with the generated text as
        ``predicted_class``. ``probabilities`` is empty and ``confidence`` is
        1.0, since neither applies to generation.

    Raises
    ------
    ValueError
        If the mode is unrecognised, or if ``participant_ids`` does not match
        ``items`` in length in a random-effects mode.
    """
    self.model.eval()
    if not items:
        return []

    input_texts = self._prepare_inputs(items)
    mode = self.config.mixed_effects.mode

    with torch.no_grad():
        if mode == "fixed":
            generated_texts: list[str] = []
            # Bound memory by generating in chunks rather than all at once.
            step = max(1, self.config.batch_size)
            for start in range(0, len(input_texts), step):
                batch_inputs = self._encode_prompts(input_texts[start : start + step])
                outputs = self.model.generate(
                    **batch_inputs,
                    max_length=self.config.max_output_length,
                    num_beams=self.config.num_beams,
                    temperature=self.config.temperature,
                    top_p=self.config.top_p,
                )
                generated_texts.extend(
                    self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
                )

        elif mode in ("random_intercepts", "random_slopes"):
            if len(participant_ids) != len(items):
                raise ValueError(
                    f"Expected {len(items)} participant_ids, "
                    f"got {len(participant_ids)}."
                )
            inputs = self._encode_prompts(input_texts)
            vocab_size = self.lm_head.out_features
            generated_texts = []
            for i, pid in enumerate(participant_ids):
                item_inputs = {k: v[i : i + 1] for k, v in inputs.items()}
                if mode == "random_intercepts":
                    decoder = self.base_decoder
                    bias = self.random_effects.get_intercepts(
                        pid,
                        n_classes=vocab_size,
                        param_name="mu",
                        create_if_missing=False,
                    )
                else:
                    decoder = self.random_effects.get_slopes(
                        pid,
                        fixed_head=create_participant_lora_adapter(
                            self.base_decoder,
                            rank=self.config.lora_rank,
                            alpha=self.config.lora_alpha,
                            dropout=self.config.lora_dropout,
                            target_modules=self.config.lora_target_modules,
                        ),
                        create_if_missing=False,
                    )
                    bias = None
                generated_texts.append(self._greedy_decode(decoder, item_inputs, bias))

        else:
            raise ValueError(f"Unsupported mixed_effects mode: {mode!r}")

    return [
        ModelPrediction(
            item_id=str(item.id),
            probabilities={},  # not applicable for generation
            predicted_class=text,  # generated text
            confidence=1.0,  # not applicable for generation
        )
        for item, text in zip(items, generated_texts, strict=True)
    ]

def _encode_prompts(self, texts: list[str]):
    """Tokenize encoder inputs (padded, truncated) and move them to the device."""
    return self.tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=self.config.max_input_length,
        return_tensors="pt",
    ).to(self.config.device)

def _greedy_decode(self, decoder, item_inputs, bias) -> str:
    """Greedily decode one item, optionally adding a logit bias at each step.

    Parameters
    ----------
    decoder : callable
        Decoder invoked with ``input_ids``, ``encoder_hidden_states`` and
        ``encoder_attention_mask``. This is either the shared base decoder or
        a participant's LoRA decoder.
    item_inputs : dict
        Tokenized encoder inputs for a single item (rows sliced ``i:i+1``).
    bias : torch.Tensor | None
        Bias added to every step's logits via ``unsqueeze(0)``, or ``None``.

    Returns
    -------
    str
        Decoded text with special tokens removed. At most
        ``config.max_output_length`` tokens are generated; decoding stops
        early once the EOS token is produced.
    """
    encoder_hidden = self.encoder(**item_inputs).last_hidden_state

    # Seed with the model's configured decoder start token (typically pad for
    # T5, eos for BART). Fall back to pad if the config does not set one.
    start_id = getattr(self.model.config, "decoder_start_token_id", None)
    if start_id is None:
        start_id = self.tokenizer.pad_token_id
    decoder_input_ids = torch.tensor([[start_id]], device=self.config.device)

    generated_ids: list[int] = []
    for _ in range(self.config.max_output_length):
        decoder_outputs = decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden,
            encoder_attention_mask=item_inputs["attention_mask"],
        )
        # NOTE(review): projects with lm_head alone; the wrapped model's own
        # forward may post-process hidden states/logits (see T5/BART).
        logits = self.lm_head(decoder_outputs.last_hidden_state[:, -1, :])
        if bias is not None:
            logits = logits + bias.unsqueeze(0)

        next_token_id = torch.argmax(logits, dim=-1)
        token = next_token_id.item()
        generated_ids.append(token)
        if token == self.tokenizer.eos_token_id:
            break
        decoder_input_ids = torch.cat(
            [decoder_input_ids, next_token_id.unsqueeze(-1)], dim=1
        )

    return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
|
|
616
|
+
|
|
617
|
+
def _do_predict_proba(
    self, items: list[Item], participant_ids: list[str]
) -> np.ndarray:
    """Return a placeholder probability matrix for free-text generation.

    Generated text has no fixed label set, so there are no per-class
    probabilities to report. The result has one row per item and zero
    columns.

    Parameters
    ----------
    items : list[Item]
        Items to predict. Only ``len(items)`` is used.
    participant_ids : list[str]
        Participant identifiers (not used by this implementation).

    Returns
    -------
    np.ndarray
        Empty float64 array of shape ``(len(items), 0)``.
    """
    n_rows = len(items)
    # Zero columns: generation has no class axis to fill.
    return np.zeros(shape=(n_rows, 0))
|
|
637
|
+
|
|
638
|
+
def _compute_exact_match(self, predictions: list[str], labels: list[str]) -> float:
|
|
639
|
+
"""Compute exact match accuracy.
|
|
640
|
+
|
|
641
|
+
Parameters
|
|
642
|
+
----------
|
|
643
|
+
predictions : list[str]
|
|
644
|
+
Predicted texts.
|
|
645
|
+
labels : list[str]
|
|
646
|
+
Ground truth texts.
|
|
647
|
+
|
|
648
|
+
Returns
|
|
649
|
+
-------
|
|
650
|
+
float
|
|
651
|
+
Exact match accuracy (fraction of exact matches).
|
|
652
|
+
"""
|
|
653
|
+
return sum(
|
|
654
|
+
p.strip().lower() == label.strip().lower()
|
|
655
|
+
for p, label in zip(predictions, labels, strict=True)
|
|
656
|
+
) / len(predictions)
|
|
657
|
+
|
|
658
|
+
def _save_model_components(self, save_path: Path) -> None:
|
|
659
|
+
"""Save model-specific components (model, tokenizer).
|
|
660
|
+
|
|
661
|
+
Parameters
|
|
662
|
+
----------
|
|
663
|
+
save_path : Path
|
|
664
|
+
Directory path to save the model.
|
|
665
|
+
"""
|
|
666
|
+
self.model.save_pretrained(save_path / "model")
|
|
667
|
+
self.tokenizer.save_pretrained(save_path / "model")
|
|
668
|
+
|
|
669
|
+
def _load_model_components(self, load_path: Path) -> None:
    """Restore config, model, tokenizer and derived submodules from disk.

    Reads ``load_path / "config.json"`` to rebuild ``self.config``. A nested
    ``mixed_effects`` dict is converted into a ``MixedEffectsConfig`` first.
    The method then loads the seq2seq model and tokenizer from
    ``load_path / "model"``, re-extracts the encoder, decoder and LM head,
    and moves the model to ``self.config.device``.

    Parameters
    ----------
    load_path : Path
        Directory path to load the model from.
    """
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
    # relying on the platform's locale-dependent default text encoding.
    with open(load_path / "config.json", encoding="utf-8") as f:
        config_dict = json.load(f)

    # Reconstruct MixedEffectsConfig if it was serialized as a plain dict
    if "mixed_effects" in config_dict and isinstance(
        config_dict["mixed_effects"], dict
    ):
        from bead.active_learning.config import MixedEffectsConfig  # noqa: PLC0415

        config_dict["mixed_effects"] = MixedEffectsConfig(
            **config_dict["mixed_effects"]
        )

    from bead.config.active_learning import FreeTextModelConfig  # noqa: PLC0415

    self.config = FreeTextModelConfig(**config_dict)

    # Load model and tokenizer from the shared "model" subdirectory
    self.model = AutoModelForSeq2SeqLM.from_pretrained(load_path / "model")
    self.tokenizer = AutoTokenizer.from_pretrained(load_path / "model")

    # Re-extract components so they reference the freshly loaded model
    self.encoder = self.model.get_encoder()
    self.base_decoder = self.model.get_decoder()
    self.lm_head = self.model.lm_head

    # nn.Module.to() moves parameters in place, so the submodules extracted
    # above follow the model onto the target device.
    self.model.to(self.config.device)
|
|
705
|
+
|
|
706
|
+
def _get_save_state(self) -> dict[str, object]:
|
|
707
|
+
"""Get model-specific state to save in config.json.
|
|
708
|
+
|
|
709
|
+
Returns
|
|
710
|
+
-------
|
|
711
|
+
dict[str, object]
|
|
712
|
+
Model-specific state dictionary.
|
|
713
|
+
"""
|
|
714
|
+
return {}
|
|
715
|
+
|
|
716
|
+
def _restore_training_state(self, config_dict: dict[str, object]) -> None:
|
|
717
|
+
"""Restore model-specific training state from config_dict.
|
|
718
|
+
|
|
719
|
+
Parameters
|
|
720
|
+
----------
|
|
721
|
+
config_dict : dict[str, object]
|
|
722
|
+
Configuration dictionary.
|
|
723
|
+
"""
|
|
724
|
+
pass
|
|
725
|
+
|
|
726
|
+
def _get_n_classes_for_random_effects(self) -> int:
|
|
727
|
+
"""Get the number of classes for initializing RandomEffectsManager.
|
|
728
|
+
|
|
729
|
+
For FreeTextModel, this is the vocabulary size.
|
|
730
|
+
|
|
731
|
+
Returns
|
|
732
|
+
-------
|
|
733
|
+
int
|
|
734
|
+
Vocabulary size.
|
|
735
|
+
"""
|
|
736
|
+
return self.lm_head.out_features
|
|
737
|
+
|
|
738
|
+
def _initialize_random_effects(self, n_classes: int, **kwargs: object) -> None:
    """Create ``self.random_effects`` for this model.

    Builds a ``RandomEffectsManager`` from ``self.config.mixed_effects``.
    ``n_classes`` is passed as its ``vocab_size`` keyword.

    Parameters
    ----------
    n_classes : int
        Vocabulary size (for FreeTextModel).
    **kwargs : object
        Additional keyword arguments (ignored).
    """
    mixed_effects_cfg = self.config.mixed_effects
    # Free-text random effects are sized by vocabulary, not by label count.
    self.random_effects = RandomEffectsManager(
        mixed_effects_cfg,
        vocab_size=n_classes,
    )
|
|
752
|
+
|
|
753
|
+
def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
|
|
754
|
+
"""Get the fixed head for random effects.
|
|
755
|
+
|
|
756
|
+
For FreeTextModel with random_slopes, returns a template adapter.
|
|
757
|
+
For other modes, returns None.
|
|
758
|
+
|
|
759
|
+
Returns
|
|
760
|
+
-------
|
|
761
|
+
torch.nn.Module | None
|
|
762
|
+
Template adapter for random_slopes, None otherwise.
|
|
763
|
+
"""
|
|
764
|
+
if self.config.mixed_effects.mode == "random_slopes":
|
|
765
|
+
# For random_slopes, need to provide a template adapter
|
|
766
|
+
return create_participant_lora_adapter(
|
|
767
|
+
self.base_decoder,
|
|
768
|
+
rank=self.config.lora_rank,
|
|
769
|
+
alpha=self.config.lora_alpha,
|
|
770
|
+
dropout=self.config.lora_dropout,
|
|
771
|
+
target_modules=self.config.lora_target_modules,
|
|
772
|
+
)
|
|
773
|
+
return None
|