bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0

bead/active_learning/trainers/model_wrapper.py
@@ -0,0 +1,509 @@
"""Model wrapper for HuggingFace Trainer integration.

This module provides wrapper models that combine encoder and classifier
head into a single model compatible with HuggingFace Trainer.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput

if TYPE_CHECKING:
    from transformers import PreTrainedModel


class EncoderClassifierWrapper(nn.Module):
    """Wrapper that combines encoder and classifier for HuggingFace Trainer.

    This wrapper takes a transformer encoder and a classifier head and
    combines them into a single model that HuggingFace Trainer can use.
    The forward method takes standard HuggingFace inputs (input_ids, etc.)
    and returns outputs with .logits attribute.

    Parameters
    ----------
    encoder : PreTrainedModel
        Transformer encoder (e.g., BERT, RoBERTa).
    classifier_head : nn.Module
        Classification head that takes encoder outputs.

    Attributes
    ----------
    encoder : PreTrainedModel
        Transformer encoder.
    classifier_head : nn.Module
        Classification head.

    Examples
    --------
    >>> from transformers import AutoModel, AutoModelForSequenceClassification
    >>> encoder = AutoModel.from_pretrained('bert-base-uncased')
    >>> classifier = nn.Linear(768, 1)  # Binary classification
    >>> model = EncoderClassifierWrapper(encoder, classifier)
    >>> outputs = model(input_ids=..., attention_mask=...)
    >>> logits = outputs.logits
    """

    def __init__(
        self,
        encoder: PreTrainedModel,
        classifier_head: nn.Module,
    ) -> None:
        """Initialize wrapper.

        Parameters
        ----------
        encoder : PreTrainedModel
            Transformer encoder.
        classifier_head : nn.Module
            Classification head.
        """
        super().__init__()
        self.encoder = encoder
        self.classifier_head = classifier_head

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        **kwargs: torch.Tensor,
    ) -> SequenceClassifierOutput:
        """Forward pass through encoder and classifier.

        Parameters
        ----------
        input_ids : torch.Tensor | None
            Token IDs.
        attention_mask : torch.Tensor | None
            Attention mask.
        token_type_ids : torch.Tensor | None
            Token type IDs (for BERT-style models).
        **kwargs : torch.Tensor
            Additional model inputs.

        Returns
        -------
        SequenceClassifierOutput
            Outputs with .logits attribute (for HuggingFace compatibility).
        """
        # Encoder forward pass
        encoder_inputs: dict[str, torch.Tensor] = {}
        if input_ids is not None:
            encoder_inputs["input_ids"] = input_ids
        if attention_mask is not None:
            encoder_inputs["attention_mask"] = attention_mask
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids

        # Add any other kwargs that encoder might accept
        for key, value in kwargs.items():
            if key not in ("labels", "participant_id"):
                encoder_inputs[key] = value

        encoder_outputs = self.encoder(**encoder_inputs)

        # Extract [CLS] token representation (first token)
        # Shape: (batch_size, hidden_size)
        if hasattr(encoder_outputs, "last_hidden_state"):
            cls_embedding = encoder_outputs.last_hidden_state[:, 0, :]
        elif hasattr(encoder_outputs, "pooler_output"):
            cls_embedding = encoder_outputs.pooler_output
        else:
            # Fallback: use first token from sequence
            cls_embedding = encoder_outputs[0][:, 0, :]

        # Classifier forward pass
        logits = self.classifier_head(cls_embedding)

        # Return SequenceClassifierOutput for HuggingFace compatibility
        # This is the standard output format that Trainer expects
        return SequenceClassifierOutput(logits=logits)


class EncoderRegressionWrapper(nn.Module):
    """Wrapper that combines encoder and regression head for HuggingFace Trainer.

    This wrapper takes a transformer encoder and a regression head and
    combines them into a single model that HuggingFace Trainer can use.
    The forward method takes standard HuggingFace inputs (input_ids, etc.)
    and returns outputs with .logits attribute (for regression, logits
    represents continuous values).

    Parameters
    ----------
    encoder : PreTrainedModel
        Transformer encoder (e.g., BERT, RoBERTa).
    regression_head : nn.Module
        Regression head that takes encoder outputs and outputs continuous values.

    Attributes
    ----------
    encoder : PreTrainedModel
        Transformer encoder.
    regression_head : nn.Module
        Regression head.

    Examples
    --------
    >>> from transformers import AutoModel
    >>> encoder = AutoModel.from_pretrained('bert-base-uncased')
    >>> regressor = nn.Linear(768, 1)  # Single continuous output
    >>> model = EncoderRegressionWrapper(encoder, regressor)
    >>> outputs = model(input_ids=..., attention_mask=...)
    >>> predictions = outputs.logits.squeeze()  # Continuous values
    """

    def __init__(
        self,
        encoder: PreTrainedModel,
        regression_head: nn.Module,
    ) -> None:
        """Initialize wrapper.

        Parameters
        ----------
        encoder : PreTrainedModel
            Transformer encoder.
        regression_head : nn.Module
            Regression head.
        """
        super().__init__()
        self.encoder = encoder
        self.regression_head = regression_head

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        **kwargs: torch.Tensor,
    ) -> SequenceClassifierOutput:
        """Forward pass through encoder and regression head.

        Parameters
        ----------
        input_ids : torch.Tensor | None
            Token IDs.
        attention_mask : torch.Tensor | None
            Attention mask.
        token_type_ids : torch.Tensor | None
            Token type IDs (for BERT-style models).
        **kwargs : torch.Tensor
            Additional model inputs.

        Returns
        -------
        SequenceClassifierOutput
            Outputs with .logits attribute containing continuous values.
        """
        # Encoder forward pass
        encoder_inputs: dict[str, torch.Tensor] = {}
        if input_ids is not None:
            encoder_inputs["input_ids"] = input_ids
        if attention_mask is not None:
            encoder_inputs["attention_mask"] = attention_mask
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids

        # Add any other kwargs that encoder might accept
        for key, value in kwargs.items():
            if key not in ("labels", "participant_id"):
                encoder_inputs[key] = value

        encoder_outputs = self.encoder(**encoder_inputs)

        # Extract [CLS] token representation (first token)
        if hasattr(encoder_outputs, "last_hidden_state"):
            cls_embedding = encoder_outputs.last_hidden_state[:, 0, :]
        elif hasattr(encoder_outputs, "pooler_output"):
            cls_embedding = encoder_outputs.pooler_output
        else:
            # Fallback: use first token from sequence
            cls_embedding = encoder_outputs[0][:, 0, :]

        # Regression head forward pass
        # Output shape: (batch_size, 1) for single continuous value
        logits = self.regression_head(cls_embedding)

        # Return SequenceClassifierOutput for HuggingFace compatibility
        # For regression, logits represents continuous values
        return SequenceClassifierOutput(logits=logits)


class MLMModelWrapper(nn.Module):
    """Wrapper for MLM models to work with HuggingFace Trainer.

    This wrapper takes an AutoModelForMaskedLM and makes it compatible
    with the Trainer while allowing access to encoder and mlm_head separately
    for mixed effects adjustments.

    Parameters
    ----------
    model : PreTrainedModel
        AutoModelForMaskedLM model.

    Attributes
    ----------
    model : PreTrainedModel
        The MLM model.
    encoder : nn.Module
        Encoder module (extracted from model).
    mlm_head : nn.Module
        MLM head (extracted from model).

    Examples
    --------
    >>> from transformers import AutoModelForMaskedLM
    >>> model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
    >>> wrapped = MLMModelWrapper(model)
    >>> outputs = wrapped(input_ids=..., attention_mask=...)
    >>> logits = outputs.logits  # (batch, seq_len, vocab_size)
    """

    def __init__(self, model: PreTrainedModel) -> None:
        """Initialize wrapper.

        Parameters
        ----------
        model : PreTrainedModel
            AutoModelForMaskedLM model.
        """
        super().__init__()
        self.model = model

        # Extract encoder and MLM head
        if hasattr(model, "bert"):
            self.encoder = model.bert
            self.mlm_head = model.cls
        elif hasattr(model, "roberta"):
            self.encoder = model.roberta
            self.mlm_head = model.lm_head
        else:
            # Fallback: try base_model and lm_head
            self.encoder = model.base_model
            self.mlm_head = model.lm_head

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        **kwargs: torch.Tensor,
    ) -> MaskedLMOutput:
        """Forward pass through MLM model.

        Parameters
        ----------
        input_ids : torch.Tensor | None
            Token IDs.
        attention_mask : torch.Tensor | None
            Attention mask.
        token_type_ids : torch.Tensor | None
            Token type IDs (for BERT-style models).
        **kwargs : torch.Tensor
            Additional model inputs.

        Returns
        -------
        MaskedLMOutput
            Model outputs with .logits attribute (shape: batch, seq_len, vocab_size).
        """
        # Forward through full model
        encoder_inputs: dict[str, torch.Tensor] = {}
        if input_ids is not None:
            encoder_inputs["input_ids"] = input_ids
        if attention_mask is not None:
            encoder_inputs["attention_mask"] = attention_mask
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids

        # Add any other kwargs that model might accept
        for key, value in kwargs.items():
            if key not in (
                "labels",
                "participant_id",
                "masked_positions",
                "target_token_ids",
            ):
                encoder_inputs[key] = value

        # Use the full model's forward pass
        outputs = self.model(**encoder_inputs)
        return outputs


class RandomSlopesModelWrapper(nn.Module):
    """Wrapper for random slopes with per-participant classifier heads.

    This wrapper combines:
    - A shared encoder (transformer backbone)
    - A fixed classifier head (population-level)
    - Per-participant heads via RandomEffectsManager

    During forward pass, each sample is routed through its participant's
    specific classifier head. New participant heads are created on-demand
    by cloning the fixed head.

    Parameters
    ----------
    encoder : PreTrainedModel
        Transformer encoder (e.g., BERT, RoBERTa).
    classifier_head : nn.Module
        Fixed/population-level classification head.
    random_effects_manager : object
        RandomEffectsManager instance that stores participant slopes.

    Attributes
    ----------
    encoder : PreTrainedModel
        Transformer encoder.
    classifier_head : nn.Module
        Fixed classification head (used as template for new participants).
    random_effects_manager : object
        Manager for participant-specific heads.

    Examples
    --------
    >>> from transformers import AutoModel
    >>> from bead.active_learning.models.random_effects import RandomEffectsManager
    >>> encoder = AutoModel.from_pretrained('bert-base-uncased')
    >>> classifier = nn.Linear(768, 2)  # Binary classification
    >>> manager = RandomEffectsManager(config, n_classes=2)
    >>> model = RandomSlopesModelWrapper(encoder, classifier, manager)
    >>> outputs = model(input_ids=..., attention_mask=..., participant_id=['p1', 'p2'])
    >>> logits = outputs.logits
    """

    def __init__(
        self,
        encoder: PreTrainedModel,
        classifier_head: nn.Module,
        random_effects_manager: object,
    ) -> None:
        """Initialize wrapper.

        Parameters
        ----------
        encoder : PreTrainedModel
            Transformer encoder.
        classifier_head : nn.Module
            Fixed classification head.
        random_effects_manager : object
            RandomEffectsManager for participant heads.
        """
        super().__init__()
        self.encoder = encoder
        self.classifier_head = classifier_head
        self.random_effects_manager = random_effects_manager

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        participant_id: list[str] | None = None,
        **kwargs: torch.Tensor,
    ) -> SequenceClassifierOutput:
        """Forward pass through encoder and participant-specific heads.

        Each sample is routed through its participant's classifier head.
        If participant_id is None, uses the fixed (population) head.

        Parameters
        ----------
        input_ids : torch.Tensor | None
            Token IDs.
        attention_mask : torch.Tensor | None
            Attention mask.
        token_type_ids : torch.Tensor | None
            Token type IDs (for BERT-style models).
        participant_id : list[str] | None
            List of participant IDs for each sample in the batch.
            If None, uses fixed head for all samples.
        **kwargs : torch.Tensor
            Additional model inputs.

        Returns
        -------
        SequenceClassifierOutput
            Outputs with .logits attribute (for HuggingFace compatibility).
        """
        # Encoder forward pass
        encoder_inputs: dict[str, torch.Tensor] = {}
        if input_ids is not None:
            encoder_inputs["input_ids"] = input_ids
        if attention_mask is not None:
            encoder_inputs["attention_mask"] = attention_mask
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids

        # Add any other kwargs that encoder might accept
        for key, value in kwargs.items():
            if key not in ("labels", "participant_id"):
                encoder_inputs[key] = value

        encoder_outputs = self.encoder(**encoder_inputs)

        # Extract [CLS] token representation (first token)
        if hasattr(encoder_outputs, "last_hidden_state"):
            cls_embedding = encoder_outputs.last_hidden_state[:, 0, :]
        elif hasattr(encoder_outputs, "pooler_output"):
            cls_embedding = encoder_outputs.pooler_output
        else:
            # Fallback: use first token from sequence
            cls_embedding = encoder_outputs[0][:, 0, :]

        # Route through participant-specific heads
        if participant_id is None:
            # No participant IDs - use fixed head for all
            logits = self.classifier_head(cls_embedding)
        else:
            # Per-participant routing
            logits_list: list[torch.Tensor] = []
            for i, pid in enumerate(participant_id):
                # Get or create participant-specific head
                participant_head = self.random_effects_manager.get_slopes(
                    pid,
                    fixed_head=self.classifier_head,
                    create_if_missing=True,
                )
                # Forward single sample through participant's head
                sample_embedding = cls_embedding[i : i + 1]  # Keep batch dimension
                sample_logits = participant_head(sample_embedding)
                logits_list.append(sample_logits)

            # Concatenate all logits
            logits = torch.cat(logits_list, dim=0)

        # Return SequenceClassifierOutput for HuggingFace compatibility
        return SequenceClassifierOutput(logits=logits)

    def get_all_parameters(self) -> list[nn.Parameter]:
        """Get all parameters including dynamically created participant heads.

        This method collects parameters from:
        1. The encoder
        2. The fixed classifier head
        3. All participant-specific heads (slopes)

        Returns
        -------
        list[nn.Parameter]
            List of all model parameters.
        """
        params: list[nn.Parameter] = []
        params.extend(self.encoder.parameters())
        params.extend(self.classifier_head.parameters())

        # Add participant head parameters if available
        if hasattr(self.random_effects_manager, "slopes"):
            for head in self.random_effects_manager.slopes.values():
                if hasattr(head, "parameters"):
                    params.extend(head.parameters())

        return params
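
The wrappers above only package a forward pass that yields a `.logits` attribute; loss computation is left to the surrounding trainer code. For orientation, the following is a minimal, hypothetical sketch of calling EncoderClassifierWrapper directly, outside any bead trainer. It assumes torch, transformers, and an accessible 'bert-base-uncased' checkpoint, and it uses a plain 2-way linear head as a stand-in for the package's own classifier heads.

# Illustrative only -- not part of the wheel.
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

from bead.active_learning.trainers.model_wrapper import EncoderClassifierWrapper

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder = AutoModel.from_pretrained("bert-base-uncased")
head = nn.Linear(encoder.config.hidden_size, 2)  # hypothetical 2-class head

model = EncoderClassifierWrapper(encoder, head)
batch = tokenizer(
    ["the cat slept", "the cat sleeped"], padding=True, return_tensors="pt"
)

with torch.no_grad():
    # Returns a SequenceClassifierOutput carrying logits only (no loss field).
    outputs = model(**batch)

print(outputs.logits.shape)  # torch.Size([2, 2])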

bead/active_learning/trainers/registry.py
@@ -0,0 +1,104 @@
"""Trainer registry for framework selection.

This module provides a registry for managing different trainer implementations,
allowing users to select trainers by name.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from bead.active_learning.trainers.huggingface import HuggingFaceTrainer
from bead.active_learning.trainers.lightning import PyTorchLightningTrainer

if TYPE_CHECKING:
    from bead.active_learning.trainers.base import BaseTrainer

_TRAINERS: dict[str, type[BaseTrainer]] = {}


def register_trainer(name: str, trainer_class: type[BaseTrainer]) -> None:
    """Register a trainer class.

    Parameters
    ----------
    name : str
        Trainer name (e.g., "huggingface", "pytorch_lightning").
    trainer_class : type[BaseTrainer]
        Trainer class to register.

    Examples
    --------
    >>> from bead.active_learning.trainers.base import BaseTrainer
    >>> class MyTrainer(BaseTrainer):  # doctest: +SKIP
    ...     def train(self, train_data, eval_data=None):
    ...         pass
    ...     def save_model(self, output_dir, metadata):
    ...         pass
    ...     def load_model(self, model_dir):
    ...         pass
    >>> register_trainer("my_trainer", MyTrainer)  # doctest: +SKIP
    >>> "my_trainer" in list_trainers()  # doctest: +SKIP
    True
    """
    _TRAINERS[name] = trainer_class


def get_trainer(name: str) -> type[BaseTrainer]:
    """Get trainer class by name.

    Parameters
    ----------
    name : str
        Trainer name.

    Returns
    -------
    type[BaseTrainer]
        Trainer class.

    Raises
    ------
    ValueError
        If trainer name is not registered.

    Examples
    --------
    >>> trainer_class = get_trainer("huggingface")
    >>> trainer_class.__name__
    'HuggingFaceTrainer'
    >>> get_trainer("unknown")  # doctest: +SKIP
    Traceback (most recent call last):
    ...
    ValueError: Unknown trainer: unknown. Available trainers: huggingface,
    pytorch_lightning
    """
    if name not in _TRAINERS:
        available = ", ".join(list_trainers())
        msg = f"Unknown trainer: {name}. Available trainers: {available}"
        raise ValueError(msg)
    return _TRAINERS[name]


def list_trainers() -> list[str]:
    """List available trainers.

    Returns
    -------
    list[str]
        List of registered trainer names.

    Examples
    --------
    >>> trainers = list_trainers()
    >>> "huggingface" in trainers
    True
    >>> "pytorch_lightning" in trainers
    True
    """
    return list(_TRAINERS.keys())


# Register built-in trainers
register_trainer("huggingface", HuggingFaceTrainer)
register_trainer("pytorch_lightning", PyTorchLightningTrainer)
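
A short usage sketch of the registry above, drawn directly from its docstrings; the printed values simply restate the behaviour those docstrings document.

# Illustrative only -- not part of the wheel.
from bead.active_learning.trainers.registry import get_trainer, list_trainers

print(list_trainers())        # ['huggingface', 'pytorch_lightning'] (registration order)

trainer_cls = get_trainer("huggingface")
print(trainer_cls.__name__)   # 'HuggingFaceTrainer'

try:
    get_trainer("nonexistent")
except ValueError as err:
    print(err)                # Unknown trainer: nonexistent. Available trainers: ...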

bead/adapters/__init__.py
@@ -0,0 +1,11 @@
"""Shared adapter utilities.

Provides base classes and utilities for integrating with external ML
frameworks like HuggingFace Transformers.
"""

from __future__ import annotations

from bead.adapters.huggingface import HuggingFaceAdapterMixin

__all__ = ["HuggingFaceAdapterMixin"]

bead/adapters/huggingface.py
@@ -0,0 +1,61 @@
"""Shared utilities for HuggingFace Transformers adapters.

This module provides common functionality for adapters that integrate with
HuggingFace Transformers models, including device validation and shared
utilities.
"""

from __future__ import annotations

import logging
from typing import Literal

import torch.backends.mps
import torch.cuda

logger = logging.getLogger(__name__)

DeviceType = Literal["cpu", "cuda", "mps"]


def _cuda_available() -> bool:
    """Check if CUDA is available."""
    return torch.cuda.is_available()  # pyright: ignore[reportAttributeAccessIssue]


def _mps_available() -> bool:
    """Check if MPS (Apple Silicon) is available."""
    return torch.backends.mps.is_available()  # pyright: ignore[reportAttributeAccessIssue]


class HuggingFaceAdapterMixin:
    """Mixin providing common HuggingFace adapter functionality.

    This mixin provides device validation with automatic fallback.

    Attributes
    ----------
    device : DeviceType
        The validated device (cpu, cuda, or mps).
    """

    def _validate_device(self, device: DeviceType) -> DeviceType:
        """Validate device and fallback if unavailable.

        Parameters
        ----------
        device : DeviceType
            Requested device.

        Returns
        -------
        DeviceType
            Validated device (falls back to CPU if unavailable).
        """
        if device == "cuda" and not _cuda_available():
            logger.warning("CUDA unavailable, using CPU")
            return "cpu"
        if device == "mps" and not _mps_available():
            logger.warning("MPS unavailable, using CPU")
            return "cpu"
        return device
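
A hypothetical consumer of the mixin above; ExampleAdapter is not part of the package and is shown only to illustrate the device-fallback behaviour that _validate_device provides.

# Illustrative only -- not part of the wheel.
from bead.adapters.huggingface import DeviceType, HuggingFaceAdapterMixin


class ExampleAdapter(HuggingFaceAdapterMixin):
    def __init__(self, device: DeviceType = "cpu") -> None:
        # _validate_device() logs a warning and returns "cpu" when the
        # requested accelerator is not available on this machine.
        self.device = self._validate_device(device)


adapter = ExampleAdapter(device="cuda")
print(adapter.device)  # "cuda" with a working GPU build of torch, otherwise "cpu"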