bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/trainers/mixed_effects.py
@@ -0,0 +1,551 @@
+"""Mixed effects trainer for HuggingFace models.
+
+This module provides a custom Trainer that handles participant-level
+random effects (intercepts and slopes) while using HuggingFace Trainer
+infrastructure for optimization, checkpointing, and device management.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn.functional
+from transformers import Trainer, TrainingArguments
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
+
+from bead.active_learning.models.random_effects import RandomEffectsManager
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    import torch
+    from torch.utils.data import Dataset
+    from transformers import PreTrainedTokenizerBase
+
+
+class MixedEffectsTrainer(Trainer):
+    """HuggingFace Trainer with mixed effects support.
+
+    Extends HuggingFace Trainer to handle participant-level random effects
+    (random intercepts and random slopes) while using Trainer's
+    optimization, checkpointing, and device management.
+
+    The key innovation is overriding compute_loss to apply participant-specific
+    adjustments to model outputs before computing the loss. This preserves
+    the mixed effects functionality while using HuggingFace infrastructure.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The model to train (must support mixed effects).
+    args : TrainingArguments
+        HuggingFace training arguments.
+    train_dataset : Dataset
+        Training dataset (must include 'participant_id' field).
+    eval_dataset : Dataset | None
+        Evaluation dataset (optional).
+    random_effects_manager : RandomEffectsManager
+        Manager for participant-level random effects.
+    data_collator : Callable | None
+        Data collator (optional, uses default if None).
+    tokenizer : PreTrainedTokenizerBase | None
+        Tokenizer (optional, for data collation).
+    compute_metrics : Callable[[object], dict[str, float]] | None
+        Metrics computation function (optional).
+
+    Attributes
+    ----------
+    random_effects_manager : RandomEffectsManager
+        Manager for random effects.
+    mixed_effects_config : MixedEffectsConfig
+        Mixed effects configuration.
+
+    Examples
+    --------
+    >>> from transformers import AutoModelForSequenceClassification, TrainingArguments
+    >>> from datasets import Dataset
+    >>> config = MixedEffectsConfig(mode='random_intercepts')
+    >>> manager = RandomEffectsManager(config, n_classes=2)
+    >>> model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
+    >>> trainer = MixedEffectsTrainer(
+    ...     model=model,
+    ...     args=TrainingArguments(output_dir='./output'),
+    ...     train_dataset=dataset,
+    ...     random_effects_manager=manager
+    ... )
+    >>> trainer.train()
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        args: TrainingArguments,
+        train_dataset: Dataset,
+        random_effects_manager: RandomEffectsManager,
+        eval_dataset: Dataset | None = None,
+        data_collator: (
+            Callable[
+                [list[dict[str, torch.Tensor | str | int | float]]],
+                dict[str, torch.Tensor | list[str]],
+            ]
+            | None
+        ) = None,
+        tokenizer: PreTrainedTokenizerBase | None = None,
+        compute_metrics: Callable[[object], dict[str, float]] | None = None,
+    ) -> None:
+        """Initialize mixed effects trainer.
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to train.
+        args : TrainingArguments
+            Training arguments.
+        train_dataset : Dataset
+            Training dataset.
+        random_effects_manager : RandomEffectsManager
+            Random effects manager.
+        eval_dataset : Dataset | None
+            Evaluation dataset.
+        data_collator : Callable | None
+            Data collator.
+        tokenizer : PreTrainedTokenizerBase | None
+            Tokenizer.
+        compute_metrics : Callable[[object], dict[str, float]] | None
+            Metrics computation function.
+        """
+        super().__init__(
+            model=model,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            data_collator=data_collator,
+            processing_class=tokenizer,
+            compute_metrics=compute_metrics,
+        )
+        self.random_effects_manager = random_effects_manager
+        self.mixed_effects_config = random_effects_manager.config
+
+    def compute_loss(
+        self,
+        model: torch.nn.Module,
+        inputs: Mapping[str, torch.Tensor],
+        return_outputs: bool = False,
+        num_items_in_batch: int | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """Compute loss with mixed effects adjustments.
+
+        Overrides HuggingFace Trainer's compute_loss to:
+        1. Get standard model outputs
+        2. Apply participant-specific adjustments (intercepts/slopes)
+        3. Compute loss with prior regularization
+        4. Return loss (and optionally outputs)
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to compute loss for.
+        inputs : Mapping[str, torch.Tensor]
+            Input batch (must include 'participant_id' if mixed effects).
+            participant_id should be a list[str] in the dataset, but will be
+            converted to tensor by data collator.
+        return_outputs : bool
+            Whether to return model outputs.
+
+        Returns
+        -------
+        torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+            Loss tensor, or (loss, outputs) if return_outputs=True.
+        """
+        # Get labels and participant IDs
+        labels = inputs.get("labels")
+        participant_ids = inputs.get("participant_id")
+
+        # For random_slopes mode, pass participant_id to model (wrapper handles routing)
+        # For other modes, remove participant_id from inputs
+        if self.mixed_effects_config.mode == "random_slopes":
+            # RandomSlopesModelWrapper expects participant_id in forward()
+            model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
+        else:
+            excluded = ("labels", "participant_id")
+            model_inputs = {k: v for k, v in inputs.items() if k not in excluded}
+
+        # Standard forward pass
+        outputs = model(**model_inputs)
+        logits = outputs.logits if hasattr(outputs, "logits") else outputs
+
+        # Apply mixed effects adjustments
+        if self.mixed_effects_config.mode == "random_intercepts":
+            # Apply participant-specific biases to logits
+            if participant_ids is not None:
+                batch_size = logits.shape[0]
+                # Handle participant_ids: could be tensor of indices or list of strings
+                # In our case, we store participant_ids as strings in dataset
+                # The data collator will need to handle this specially
+                for i in range(batch_size):
+                    # Extract participant ID - data collator provides as list[str]
+                    if isinstance(participant_ids, list):
+                        pid_str = str(participant_ids[i])
+                    elif isinstance(participant_ids, torch.Tensor):
+                        # Fallback: if somehow tensor, convert
+                        pid_elem = participant_ids[i]
+                        pid_raw = pid_elem.item() if pid_elem.numel() == 1 else pid_elem
+                        pid_str = str(pid_raw)
+                    else:
+                        pid_str = str(participant_ids[i])
+
+                    # Get bias for this participant
+                    # For binary: n_classes=1 (scalar bias)
+                    n_classes = logits.shape[1] if logits.dim() > 1 else 1
+                    bias = self.random_effects_manager.get_intercepts(
+                        pid_str,
+                        n_classes=n_classes,
+                        param_name="mu",
+                        create_if_missing=True,
+                    )
+                    # Ensure bias is on same device as logits
+                    bias = bias.to(logits.device)
+                    # For binary, bias is scalar, add to logits
+                    if logits.dim() == 1:
+                        bias_val = bias[0] if bias.numel() > 0 else 0
+                        logits[i] = logits[i] + bias_val
+                    else:
+                        logits[i] = logits[i] + bias
+
+        elif self.mixed_effects_config.mode == "random_slopes":
+            # Random slopes are handled by RandomSlopesModelWrapper in forward()
+            # The model routes each sample through participant-specific heads
+            # Logits already incorporate random slopes - nothing to do here
+            pass
+
+        # Compute data loss
+        if labels is not None:
+            # Check if this is regression (continuous labels) or classification
+            # Regression: labels are float, logits shape is (batch, 1)
+            # Classification: labels are int/long, logits shape varies by task
+            if labels.dtype.is_floating_point:
+                # Regression task: use MSE loss
+                if logits.dim() == 2 and logits.shape[1] == 1:
+                    # Squeeze to (batch,)
+                    preds = logits.squeeze(1)
+                elif logits.dim() == 1:
+                    preds = logits
+                else:
+                    # Unexpected shape, use first column
+                    preds = logits[:, 0]
+                loss = torch.nn.functional.mse_loss(preds, labels.float())
+            elif logits.dim() == 1 or (logits.dim() == 2 and logits.shape[1] == 1):
+                # Binary classification
+                loss_fct = torch.nn.functional.binary_cross_entropy_with_logits
+                if labels.dim() == 0:
+                    labels = labels.unsqueeze(0)
+                if logits.dim() == 1:
+                    logits = logits.unsqueeze(1)
+                loss = loss_fct(logits.squeeze(1), labels.float())
+            else:
+                # Multi-class classification
+                loss_fct = torch.nn.functional.cross_entropy
+                loss = loss_fct(logits, labels.long())
+        else:
+            # No labels provided (unsupervised)
+            loss = torch.tensor(0.0, device=logits.device)
+
+        # Add prior regularization loss
+        loss_prior = self.random_effects_manager.compute_prior_loss()
+        if loss_prior.device != loss.device:
+            loss_prior = loss_prior.to(loss.device)
+        loss = loss + loss_prior
+
+        if return_outputs:
+            # Create output object with adjusted logits
+            adjusted_outputs = SequenceClassifierOutput(logits=logits)
+            return (loss, adjusted_outputs)
+        return loss
+
+    def create_optimizer(self) -> None:
+        """Create optimizer with all parameters including participant heads.
+
+        For random_slopes mode, this method collects parameters from:
+        1. The encoder (via model.encoder or model.model.encoder)
+        2. The fixed classifier head
+        3. All participant-specific heads (slopes)
+
+        For other modes, delegates to parent implementation.
+        """
+        if self.optimizer is not None:
+            # Optimizer already exists
+            return
+
+        if self.mixed_effects_config.mode == "random_slopes":
+            # Collect parameters for random_slopes mode
+            optimizer_grouped_parameters: list[dict[str, object]] = []
+
+            # Check if model has get_all_parameters method (RandomSlopesModelWrapper)
+            if hasattr(self.model, "get_all_parameters"):
+                all_params = self.model.get_all_parameters()
+                optimizer_grouped_parameters.append(
+                    {
+                        "params": all_params,
+                        "lr": self.args.learning_rate,
+                    }
+                )
+            else:
+                # Fallback: collect standard model parameters plus slope parameters
+                optimizer_grouped_parameters.append(
+                    {
+                        "params": list(self.model.parameters()),
+                        "lr": self.args.learning_rate,
+                    }
+                )
+
+                # Add participant head parameters from random_effects_manager
+                if hasattr(self.random_effects_manager, "slopes"):
+                    for head in self.random_effects_manager.slopes.values():
+                        if hasattr(head, "parameters"):
+                            optimizer_grouped_parameters.append(
+                                {
+                                    "params": list(head.parameters()),
+                                    "lr": self.args.learning_rate,
+                                }
+                            )
+
+            # Create AdamW optimizer
+            self.optimizer = torch.optim.AdamW(
+                optimizer_grouped_parameters,
+                lr=self.args.learning_rate,
+                weight_decay=self.args.weight_decay,
+            )
+        else:
+            # Use parent implementation for other modes
+            super().create_optimizer()
+
+
+class ClozeMLMTrainer(MixedEffectsTrainer):
+    """Custom trainer for cloze (MLM) tasks with custom masking positions.
+
+    Extends MixedEffectsTrainer to handle MLM loss computation only on
+    specific masked positions (from unfilled_slots) rather than all positions.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        MLM model (AutoModelForMaskedLM or wrapper).
+    args : TrainingArguments
+        Training arguments.
+    train_dataset : Dataset
+        Training dataset (must include 'masked_positions' and 'target_token_ids').
+    random_effects_manager : RandomEffectsManager
+        Random effects manager.
+    eval_dataset : Dataset | None
+        Evaluation dataset.
+    data_collator : Callable | None
+        Data collator (should be ClozeDataCollator).
+    tokenizer : PreTrainedTokenizerBase | None
+        Tokenizer.
+    compute_metrics : Callable[[object], dict[str, float]] | None
+        Metrics computation function.
+
+    Examples
+    --------
+    >>> from transformers import AutoModelForMaskedLM, TrainingArguments
+    >>> model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
+    >>> trainer = ClozeMLMTrainer(
+    ...     model=model,
+    ...     args=TrainingArguments(output_dir='./output'),
+    ...     train_dataset=dataset,
+    ...     random_effects_manager=manager
+    ... )
+    >>> trainer.train()
+    """
+
+    def compute_loss(
+        self,
+        model: torch.nn.Module,
+        inputs: Mapping[str, torch.Tensor | list[str] | list[list[int]]],
+        return_outputs: bool = False,
+        num_items_in_batch: int | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """Compute MLM loss only on masked positions.
+
+        Overrides MixedEffectsTrainer's compute_loss to:
+        1. Get model outputs (logits for all positions)
+        2. Apply participant-specific adjustments (intercepts) if needed
+        3. Compute cross-entropy loss only on masked positions
+        4. Add prior regularization
+        5. Return loss (and optionally outputs)
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to compute loss for.
+        inputs : Mapping[str, torch.Tensor | list[str] | list[list[int]]]
+            Input batch with:
+            - Standard tokenized inputs (input_ids, attention_mask, etc.)
+            - participant_id: list[str]
+            - masked_positions: list[list[int]] - masked token positions per item
+            - target_token_ids: list[list[int]] - target token IDs per masked position
+        return_outputs : bool
+            Whether to return model outputs.
+        num_items_in_batch : int | None
+            Unused, kept for compatibility.
+
+        Returns
+        -------
+        torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+            Loss tensor, or (loss, outputs) if return_outputs=True.
+        """
+        # Extract cloze-specific fields
+        participant_ids = inputs.get("participant_id")
+        masked_positions = inputs.get("masked_positions", [])
+        target_token_ids = inputs.get("target_token_ids", [])
+
+        # Remove these from inputs for model forward pass
+        excluded = ("labels", "participant_id", "masked_positions", "target_token_ids")
+        model_inputs = {k: v for k, v in inputs.items() if k not in excluded}
+
+        # Standard forward pass
+        outputs = model(**model_inputs)
+        logits = outputs.logits if hasattr(outputs, "logits") else outputs
+        # logits shape: (batch, seq_len, vocab_size)
+
+        # Apply mixed effects adjustments for random_intercepts
+        if self.mixed_effects_config.mode == "random_intercepts":
+            if participant_ids is not None and isinstance(participant_ids, list):
+                vocab_size = logits.shape[2]
+                batch_size = logits.shape[0]
+                for i in range(batch_size):
+                    pid_str = str(participant_ids[i])
+                    # Get bias for this participant (vocab_size,)
+                    bias = self.random_effects_manager.get_intercepts(
+                        pid_str,
+                        n_classes=vocab_size,
+                        param_name="mu",
+                        create_if_missing=True,
+                    )
+                    bias = bias.to(logits.device)
+                    # Add bias to all masked positions for this item
+                    in_range = i < len(masked_positions)
+                    if in_range and isinstance(masked_positions[i], list):
+                        for pos in masked_positions[i]:
+                            if pos < logits.shape[1]:
+                                logits[i, pos] = logits[i, pos] + bias
+
+        # Compute loss only on masked positions
+        losses: list[torch.Tensor] = []
+        if isinstance(masked_positions, list) and isinstance(target_token_ids, list):
+            for j, (masked_pos, target_ids) in enumerate(
+                zip(masked_positions, target_token_ids, strict=True)
+            ):
+                if j >= logits.shape[0]:
+                    continue
+                if isinstance(masked_pos, list) and isinstance(target_ids, list):
+                    for pos, target_id in zip(masked_pos, target_ids, strict=True):
+                        if pos < logits.shape[1]:
+                            # Cross-entropy loss for this position
+                            # logits[j, pos] shape: (vocab_size,)
+                            # target_id: int
+                            # Need shape (1, vocab_size) for logits and (1,) for target
+                            pos_logits = logits[j, pos].unsqueeze(0)  # (1, vocab_size)
+                            pos_target = torch.tensor(
+                                [target_id], device=logits.device, dtype=torch.long
+                            )  # (1,)
+                            loss_j = torch.nn.functional.cross_entropy(
+                                pos_logits, pos_target
+                            )
+                            losses.append(loss_j)
+
+        if losses:
+            loss_nll = torch.stack(losses).mean()
+        else:
+            loss_nll = torch.tensor(0.0, device=logits.device)
+
+        # Add prior regularization loss
+        loss_prior = self.random_effects_manager.compute_prior_loss()
+        if loss_prior.device != loss_nll.device:
+            loss_prior = loss_prior.to(loss_nll.device)
+        loss = loss_nll + loss_prior
+
+        if return_outputs:
+            # Return outputs with logits
+            adjusted_outputs = MaskedLMOutput(logits=logits)
+            return (loss, adjusted_outputs)
+        return loss
+
+    def prediction_step(
+        self,
+        model: torch.nn.Module,
+        inputs: dict[str, torch.Tensor | list[str] | list[list[int]]],
+        prediction_loss_only: bool,
+        ignore_keys: list[str] | None = None,
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+        """Perform a prediction step with cloze-specific label encoding.
+
+        Creates labels tensor encoding target_token_ids at masked_positions
+        with -100 elsewhere (HuggingFace ignore index convention). This enables
+        compute_cloze_metrics() to evaluate predictions at the correct positions.
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to use for prediction.
+        inputs : dict[str, torch.Tensor | list[str] | list[list[int]]]
+            Input batch with:
+            - Standard tokenized inputs (input_ids, attention_mask, etc.)
+            - participant_id: list[str]
+            - masked_positions: list[list[int]] - masked token positions per item
+            - target_token_ids: list[list[int]] - target token IDs per position
+        prediction_loss_only : bool
+            Whether to only return loss.
+        ignore_keys : list[str] | None
+            Keys to ignore (unused).
+
+        Returns
+        -------
+        tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]
+            (loss, logits, labels) tuple where labels encodes target tokens
+            at masked positions with -100 elsewhere.
+        """
+        # Extract cloze-specific fields
+        masked_positions = inputs.get("masked_positions", [])
+        target_token_ids = inputs.get("target_token_ids", [])
+
+        # Filter inputs for model forward pass
+        model_inputs = {
+            k: v
+            for k, v in inputs.items()
+            if k not in ("participant_id", "masked_positions", "target_token_ids")
+        }
+
+        # Get predictions from parent (which handles compute_loss internally)
+        loss, logits, _ = super().prediction_step(
+            model, model_inputs, prediction_loss_only, ignore_keys
+        )
+
+        if prediction_loss_only:
+            return (loss, None, None)
+
+        # Build labels tensor: (batch_size, seq_len) with -100 default
+        labels = None
+        has_masks = isinstance(masked_positions, list)
+        has_targets = isinstance(target_token_ids, list)
+        if logits is not None and has_masks and has_targets:
+            batch_size, seq_len = logits.shape[:2]
+            labels = torch.full(
+                (batch_size, seq_len), -100, dtype=torch.long, device=logits.device
+            )

+            # Fill in target token IDs at masked positions
+            for i, (positions, targets) in enumerate(
+                zip(masked_positions, target_token_ids, strict=False)
+            ):
+                if i >= batch_size:
+                    break
+                if isinstance(positions, list) and isinstance(targets, list):
+                    for pos, target_id in zip(positions, targets, strict=False):
+                        if isinstance(pos, int) and isinstance(target_id, int):
+                            if 0 <= pos < seq_len:
+                                labels[i, pos] = target_id
+
+        return (loss, logits, labels)
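For readers wiring this trainer up outside the package, the sketch below shows one plausible way to satisfy the requirements stated in the docstrings above: the training dataset carries a `participant_id` column, the collator passes those IDs through as `list[str]`, and `remove_unused_columns=False` keeps the column from being dropped before collation. This is an illustration only, not code from the wheel; the import path for `MixedEffectsConfig` and the field names `text`, `labels`, and `participant_id` are assumptions made for the example.

```python
# Illustrative sketch (not part of the released wheel). Assumes MixedEffectsConfig
# lives in bead.config.active_learning; adjust the import to wherever it is defined.
import torch
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
)

from bead.active_learning.models.random_effects import RandomEffectsManager
from bead.active_learning.trainers.mixed_effects import MixedEffectsTrainer
from bead.config.active_learning import MixedEffectsConfig  # assumed location


def collate_keep_participant_ids(features: list[dict]) -> dict:
    """Stack tokenized tensors but pass participant_id through as list[str]."""
    return {
        "input_ids": torch.tensor([f["input_ids"] for f in features]),
        "attention_mask": torch.tensor([f["attention_mask"] for f in features]),
        "labels": torch.tensor([f["labels"] for f in features]),
        "participant_id": [str(f["participant_id"]) for f in features],
    }


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)

# Toy dataset with the hypothetical columns text/labels/participant_id.
raw = Dataset.from_dict(
    {
        "text": ["the cat sat on the mat", "a dog ran in the park"],
        "labels": [1, 0],
        "participant_id": ["p01", "p02"],
    }
)
train_dataset = raw.map(
    lambda ex: tokenizer(
        ex["text"], truncation=True, padding="max_length", max_length=16
    )
)

config = MixedEffectsConfig(mode="random_intercepts")
manager = RandomEffectsManager(config, n_classes=2)

trainer = MixedEffectsTrainer(
    model=model,
    args=TrainingArguments(
        output_dir="./output",
        remove_unused_columns=False,  # keep participant_id for the collator
    ),
    train_dataset=train_dataset,
    random_effects_manager=manager,
    data_collator=collate_keep_participant_ids,
    tokenizer=tokenizer,
)
trainer.train()
```

The same wiring would apply to ClozeMLMTrainer, except that, per its docstring, the dataset and collator must also carry masked_positions and target_token_ids.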