bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/trainers/model_wrapper.py
@@ -0,0 +1,509 @@
+"""Model wrapper for HuggingFace Trainer integration.
+
+This module provides wrapper models that combine encoder and classifier
+head into a single model compatible with HuggingFace Trainer.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+
+class EncoderClassifierWrapper(nn.Module):
+    """Wrapper that combines encoder and classifier for HuggingFace Trainer.
+
+    This wrapper takes a transformer encoder and a classifier head and
+    combines them into a single model that HuggingFace Trainer can use.
+    The forward method takes standard HuggingFace inputs (input_ids, etc.)
+    and returns outputs with a .logits attribute.
+
+    Parameters
+    ----------
+    encoder : PreTrainedModel
+        Transformer encoder (e.g., BERT, RoBERTa).
+    classifier_head : nn.Module
+        Classification head that takes encoder outputs.
+
+    Attributes
+    ----------
+    encoder : PreTrainedModel
+        Transformer encoder.
+    classifier_head : nn.Module
+        Classification head.
+
+    Examples
+    --------
+    >>> from transformers import AutoModel
+    >>> encoder = AutoModel.from_pretrained('bert-base-uncased')
+    >>> classifier = nn.Linear(768, 1)  # Binary classification
+    >>> model = EncoderClassifierWrapper(encoder, classifier)
+    >>> outputs = model(input_ids=..., attention_mask=...)
+    >>> logits = outputs.logits
+    """
+
+    def __init__(
+        self,
+        encoder: PreTrainedModel,
+        classifier_head: nn.Module,
+    ) -> None:
+        """Initialize wrapper.
+
+        Parameters
+        ----------
+        encoder : PreTrainedModel
+            Transformer encoder.
+        classifier_head : nn.Module
+            Classification head.
+        """
+        super().__init__()
+        self.encoder = encoder
+        self.classifier_head = classifier_head
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        token_type_ids: torch.Tensor | None = None,
+        **kwargs: torch.Tensor,
+    ) -> SequenceClassifierOutput:
+        """Forward pass through encoder and classifier.
+
+        Parameters
+        ----------
+        input_ids : torch.Tensor | None
+            Token IDs.
+        attention_mask : torch.Tensor | None
+            Attention mask.
+        token_type_ids : torch.Tensor | None
+            Token type IDs (for BERT-style models).
+        **kwargs : torch.Tensor
+            Additional model inputs.
+
+        Returns
+        -------
+        SequenceClassifierOutput
+            Outputs with .logits attribute (for HuggingFace compatibility).
+        """
+        # Encoder forward pass
+        encoder_inputs: dict[str, torch.Tensor] = {}
+        if input_ids is not None:
+            encoder_inputs["input_ids"] = input_ids
+        if attention_mask is not None:
+            encoder_inputs["attention_mask"] = attention_mask
+        if token_type_ids is not None:
+            encoder_inputs["token_type_ids"] = token_type_ids
+
+        # Add any other kwargs that encoder might accept
+        for key, value in kwargs.items():
+            if key not in ("labels", "participant_id"):
+                encoder_inputs[key] = value
+
+        encoder_outputs = self.encoder(**encoder_inputs)
+
+        # Extract [CLS] token representation (first token)
+        # Shape: (batch_size, hidden_size)
+        if hasattr(encoder_outputs, "last_hidden_state"):
+            cls_embedding = encoder_outputs.last_hidden_state[:, 0, :]
+        elif hasattr(encoder_outputs, "pooler_output"):
+            cls_embedding = encoder_outputs.pooler_output
+        else:
+            # Fallback: use first token from sequence
+            cls_embedding = encoder_outputs[0][:, 0, :]
+
+        # Classifier forward pass
+        logits = self.classifier_head(cls_embedding)
+
+        # Return SequenceClassifierOutput for HuggingFace compatibility
+        # This is the standard output format that Trainer expects
+        return SequenceClassifierOutput(logits=logits)
+
+
+class EncoderRegressionWrapper(nn.Module):
+    """Wrapper that combines encoder and regression head for HuggingFace Trainer.
+
+    This wrapper takes a transformer encoder and a regression head and
+    combines them into a single model that HuggingFace Trainer can use.
+    The forward method takes standard HuggingFace inputs (input_ids, etc.)
+    and returns outputs with a .logits attribute (for regression, logits
+    represents continuous values).
+
+    Parameters
+    ----------
+    encoder : PreTrainedModel
+        Transformer encoder (e.g., BERT, RoBERTa).
+    regression_head : nn.Module
+        Regression head that takes encoder outputs and outputs continuous values.
+
+    Attributes
+    ----------
+    encoder : PreTrainedModel
+        Transformer encoder.
+    regression_head : nn.Module
+        Regression head.
+
+    Examples
+    --------
+    >>> from transformers import AutoModel
+    >>> encoder = AutoModel.from_pretrained('bert-base-uncased')
+    >>> regressor = nn.Linear(768, 1)  # Single continuous output
+    >>> model = EncoderRegressionWrapper(encoder, regressor)
+    >>> outputs = model(input_ids=..., attention_mask=...)
+    >>> predictions = outputs.logits.squeeze()  # Continuous values
+    """
+
+    def __init__(
+        self,
+        encoder: PreTrainedModel,
+        regression_head: nn.Module,
+    ) -> None:
+        """Initialize wrapper.
+
+        Parameters
+        ----------
+        encoder : PreTrainedModel
+            Transformer encoder.
+        regression_head : nn.Module
+            Regression head.
+        """
+        super().__init__()
+        self.encoder = encoder
+        self.regression_head = regression_head
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        token_type_ids: torch.Tensor | None = None,
+        **kwargs: torch.Tensor,
+    ) -> SequenceClassifierOutput:
+        """Forward pass through encoder and regression head.
+
+        Parameters
+        ----------
+        input_ids : torch.Tensor | None
+            Token IDs.
+        attention_mask : torch.Tensor | None
+            Attention mask.
+        token_type_ids : torch.Tensor | None
+            Token type IDs (for BERT-style models).
+        **kwargs : torch.Tensor
+            Additional model inputs.
+
+        Returns
+        -------
+        SequenceClassifierOutput
+            Outputs with .logits attribute containing continuous values.
+        """
+        # Encoder forward pass
+        encoder_inputs: dict[str, torch.Tensor] = {}
+        if input_ids is not None:
+            encoder_inputs["input_ids"] = input_ids
+        if attention_mask is not None:
+            encoder_inputs["attention_mask"] = attention_mask
+        if token_type_ids is not None:
+            encoder_inputs["token_type_ids"] = token_type_ids
+
+        # Add any other kwargs that encoder might accept
+        for key, value in kwargs.items():
+            if key not in ("labels", "participant_id"):
+                encoder_inputs[key] = value
+
+        encoder_outputs = self.encoder(**encoder_inputs)
+
+        # Extract [CLS] token representation (first token)
+        if hasattr(encoder_outputs, "last_hidden_state"):
+            cls_embedding = encoder_outputs.last_hidden_state[:, 0, :]
+        elif hasattr(encoder_outputs, "pooler_output"):
+            cls_embedding = encoder_outputs.pooler_output
+        else:
+            # Fallback: use first token from sequence
+            cls_embedding = encoder_outputs[0][:, 0, :]
+
+        # Regression head forward pass
+        # Output shape: (batch_size, 1) for single continuous value
+        logits = self.regression_head(cls_embedding)
+
+        # Return SequenceClassifierOutput for HuggingFace compatibility
+        # For regression, logits represents continuous values
+        return SequenceClassifierOutput(logits=logits)
+
+
+class MLMModelWrapper(nn.Module):
+    """Wrapper for MLM models to work with HuggingFace Trainer.
+
+    This wrapper takes an AutoModelForMaskedLM and makes it compatible
+    with the Trainer while allowing access to encoder and mlm_head separately
+    for mixed effects adjustments.
+
+    Parameters
+    ----------
+    model : PreTrainedModel
+        AutoModelForMaskedLM model.
+
+    Attributes
+    ----------
+    model : PreTrainedModel
+        The MLM model.
+    encoder : nn.Module
+        Encoder module (extracted from model).
+    mlm_head : nn.Module
+        MLM head (extracted from model).
+
+    Examples
+    --------
+    >>> from transformers import AutoModelForMaskedLM
+    >>> model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
+    >>> wrapped = MLMModelWrapper(model)
+    >>> outputs = wrapped(input_ids=..., attention_mask=...)
+    >>> logits = outputs.logits  # (batch, seq_len, vocab_size)
+    """
+
+    def __init__(self, model: PreTrainedModel) -> None:
+        """Initialize wrapper.
+
+        Parameters
+        ----------
+        model : PreTrainedModel
+            AutoModelForMaskedLM model.
+        """
+        super().__init__()
+        self.model = model
+
+        # Extract encoder and MLM head
+        if hasattr(model, "bert"):
+            self.encoder = model.bert
+            self.mlm_head = model.cls
+        elif hasattr(model, "roberta"):
+            self.encoder = model.roberta
+            self.mlm_head = model.lm_head
+        else:
+            # Fallback: try base_model and lm_head
+            self.encoder = model.base_model
+            self.mlm_head = model.lm_head
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        token_type_ids: torch.Tensor | None = None,
+        **kwargs: torch.Tensor,
+    ) -> MaskedLMOutput:
+        """Forward pass through MLM model.
+
+        Parameters
+        ----------
+        input_ids : torch.Tensor | None
+            Token IDs.
+        attention_mask : torch.Tensor | None
+            Attention mask.
+        token_type_ids : torch.Tensor | None
+            Token type IDs (for BERT-style models).
+        **kwargs : torch.Tensor
+            Additional model inputs.
+
+        Returns
+        -------
+        MaskedLMOutput
+            Model outputs with .logits attribute (shape: batch, seq_len, vocab_size).
+        """
+        # Forward through full model
+        encoder_inputs: dict[str, torch.Tensor] = {}
+        if input_ids is not None:
+            encoder_inputs["input_ids"] = input_ids
+        if attention_mask is not None:
+            encoder_inputs["attention_mask"] = attention_mask
+        if token_type_ids is not None:
+            encoder_inputs["token_type_ids"] = token_type_ids
+
+        # Add any other kwargs that model might accept
+        for key, value in kwargs.items():
+            if key not in (
+                "labels",
+                "participant_id",
+                "masked_positions",
+                "target_token_ids",
+            ):
+                encoder_inputs[key] = value
+
+        # Use the full model's forward pass
+        outputs = self.model(**encoder_inputs)
+        return outputs
+
+
+class RandomSlopesModelWrapper(nn.Module):
+    """Wrapper for random slopes with per-participant classifier heads.
+
+    This wrapper combines:
+    - A shared encoder (transformer backbone)
+    - A fixed classifier head (population-level)
+    - Per-participant heads via RandomEffectsManager
+
+    During forward pass, each sample is routed through its participant's
+    specific classifier head. New participant heads are created on-demand
+    by cloning the fixed head.
+
+    Parameters
+    ----------
+    encoder : PreTrainedModel
+        Transformer encoder (e.g., BERT, RoBERTa).
+    classifier_head : nn.Module
+        Fixed/population-level classification head.
+    random_effects_manager : object
+        RandomEffectsManager instance that stores participant slopes.
+
+    Attributes
+    ----------
+    encoder : PreTrainedModel
+        Transformer encoder.
+    classifier_head : nn.Module
+        Fixed classification head (used as template for new participants).
+    random_effects_manager : object
+        Manager for participant-specific heads.
+
+    Examples
+    --------
+    >>> from transformers import AutoModel
+    >>> from bead.active_learning.models.random_effects import RandomEffectsManager
+    >>> encoder = AutoModel.from_pretrained('bert-base-uncased')
+    >>> classifier = nn.Linear(768, 2)  # Binary classification
+    >>> manager = RandomEffectsManager(config, n_classes=2)
+    >>> model = RandomSlopesModelWrapper(encoder, classifier, manager)
+    >>> outputs = model(input_ids=..., attention_mask=..., participant_id=['p1', 'p2'])
+    >>> logits = outputs.logits
+    """
+
+    def __init__(
+        self,
+        encoder: PreTrainedModel,
+        classifier_head: nn.Module,
+        random_effects_manager: object,
+    ) -> None:
+        """Initialize wrapper.
+
+        Parameters
+        ----------
+        encoder : PreTrainedModel
+            Transformer encoder.
+        classifier_head : nn.Module
+            Fixed classification head.
+        random_effects_manager : object
+            RandomEffectsManager for participant heads.
+        """
+        super().__init__()
+        self.encoder = encoder
+        self.classifier_head = classifier_head
+        self.random_effects_manager = random_effects_manager
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        token_type_ids: torch.Tensor | None = None,
+        participant_id: list[str] | None = None,
+        **kwargs: torch.Tensor,
+    ) -> SequenceClassifierOutput:
+        """Forward pass through encoder and participant-specific heads.
+
+        Each sample is routed through its participant's classifier head.
+        If participant_id is None, uses the fixed (population) head.
+
+        Parameters
+        ----------
+        input_ids : torch.Tensor | None
+            Token IDs.
+        attention_mask : torch.Tensor | None
+            Attention mask.
+        token_type_ids : torch.Tensor | None
+            Token type IDs (for BERT-style models).
+        participant_id : list[str] | None
+            List of participant IDs for each sample in the batch.
+            If None, uses fixed head for all samples.
+        **kwargs : torch.Tensor
+            Additional model inputs.
+
+        Returns
+        -------
+        SequenceClassifierOutput
+            Outputs with .logits attribute (for HuggingFace compatibility).
+        """
+        # Encoder forward pass
+        encoder_inputs: dict[str, torch.Tensor] = {}
+        if input_ids is not None:
+            encoder_inputs["input_ids"] = input_ids
+        if attention_mask is not None:
+            encoder_inputs["attention_mask"] = attention_mask
+        if token_type_ids is not None:
+            encoder_inputs["token_type_ids"] = token_type_ids
+
+        # Add any other kwargs that encoder might accept
+        for key, value in kwargs.items():
+            if key not in ("labels", "participant_id"):
+                encoder_inputs[key] = value
+
+        encoder_outputs = self.encoder(**encoder_inputs)
+
+        # Extract [CLS] token representation (first token)
+        if hasattr(encoder_outputs, "last_hidden_state"):
+            cls_embedding = encoder_outputs.last_hidden_state[:, 0, :]
+        elif hasattr(encoder_outputs, "pooler_output"):
+            cls_embedding = encoder_outputs.pooler_output
+        else:
+            # Fallback: use first token from sequence
+            cls_embedding = encoder_outputs[0][:, 0, :]
+
+        # Route through participant-specific heads
+        if participant_id is None:
+            # No participant IDs - use fixed head for all
+            logits = self.classifier_head(cls_embedding)
+        else:
+            # Per-participant routing
+            logits_list: list[torch.Tensor] = []
+            for i, pid in enumerate(participant_id):
+                # Get or create participant-specific head
+                participant_head = self.random_effects_manager.get_slopes(
+                    pid,
+                    fixed_head=self.classifier_head,
+                    create_if_missing=True,
+                )
+                # Forward single sample through participant's head
+                sample_embedding = cls_embedding[i : i + 1]  # Keep batch dimension
+                sample_logits = participant_head(sample_embedding)
+                logits_list.append(sample_logits)
+
+            # Concatenate all logits
+            logits = torch.cat(logits_list, dim=0)
+
+        # Return SequenceClassifierOutput for HuggingFace compatibility
+        return SequenceClassifierOutput(logits=logits)
+
+    def get_all_parameters(self) -> list[nn.Parameter]:
+        """Get all parameters including dynamically created participant heads.
+
+        This method collects parameters from:
+        1. The encoder
+        2. The fixed classifier head
+        3. All participant-specific heads (slopes)
+
+        Returns
+        -------
+        list[nn.Parameter]
+            List of all model parameters.
+        """
+        params: list[nn.Parameter] = []
+        params.extend(self.encoder.parameters())
+        params.extend(self.classifier_head.parameters())
+
+        # Add participant head parameters if available
+        if hasattr(self.random_effects_manager, "slopes"):
+            for head in self.random_effects_manager.slopes.values():
+                if hasattr(head, "parameters"):
+                    params.extend(head.parameters())
+
+        return params
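For orientation, here is a minimal forward-pass sketch of `EncoderClassifierWrapper` from the hunk above. The wrapper class and its forward signature come from the diff; the checkpoint name, the tokenizer usage, and the two-class head are illustrative assumptions:

```python
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

from bead.active_learning.trainers.model_wrapper import EncoderClassifierWrapper

# Shared encoder plus a small classification head (two classes, chosen arbitrarily).
encoder = AutoModel.from_pretrained("bert-base-uncased")
head = nn.Linear(encoder.config.hidden_size, 2)
model = EncoderClassifierWrapper(encoder, head)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["the cat slept", "slept cat the"], padding=True, return_tensors="pt")

outputs = model(**batch)     # SequenceClassifierOutput, per the forward() docstring
print(outputs.logits.shape)  # torch.Size([2, 2]), i.e. (batch_size, n_classes)
```

Note that these wrappers strip `labels` (and `participant_id`) before calling the encoder and return logits without a loss, so loss computation presumably lives in the package's trainer classes rather than in this module.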
bead/active_learning/trainers/registry.py
@@ -0,0 +1,104 @@
+"""Trainer registry for framework selection.
+
+This module provides a registry for managing different trainer implementations,
+allowing users to select trainers by name.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from bead.active_learning.trainers.huggingface import HuggingFaceTrainer
+from bead.active_learning.trainers.lightning import PyTorchLightningTrainer
+
+if TYPE_CHECKING:
+    from bead.active_learning.trainers.base import BaseTrainer
+
+_TRAINERS: dict[str, type[BaseTrainer]] = {}
+
+
+def register_trainer(name: str, trainer_class: type[BaseTrainer]) -> None:
+    """Register a trainer class.
+
+    Parameters
+    ----------
+    name : str
+        Trainer name (e.g., "huggingface", "pytorch_lightning").
+    trainer_class : type[BaseTrainer]
+        Trainer class to register.
+
+    Examples
+    --------
+    >>> from bead.active_learning.trainers.base import BaseTrainer
+    >>> class MyTrainer(BaseTrainer):  # doctest: +SKIP
+    ...     def train(self, train_data, eval_data=None):
+    ...         pass
+    ...     def save_model(self, output_dir, metadata):
+    ...         pass
+    ...     def load_model(self, model_dir):
+    ...         pass
+    >>> register_trainer("my_trainer", MyTrainer)  # doctest: +SKIP
+    >>> "my_trainer" in list_trainers()  # doctest: +SKIP
+    True
+    """
+    _TRAINERS[name] = trainer_class
+
+
+def get_trainer(name: str) -> type[BaseTrainer]:
+    """Get trainer class by name.
+
+    Parameters
+    ----------
+    name : str
+        Trainer name.
+
+    Returns
+    -------
+    type[BaseTrainer]
+        Trainer class.
+
+    Raises
+    ------
+    ValueError
+        If trainer name is not registered.
+
+    Examples
+    --------
+    >>> trainer_class = get_trainer("huggingface")
+    >>> trainer_class.__name__
+    'HuggingFaceTrainer'
+    >>> get_trainer("unknown")  # doctest: +SKIP
+    Traceback (most recent call last):
+        ...
+    ValueError: Unknown trainer: unknown. Available trainers: huggingface,
+    pytorch_lightning
+    """
+    if name not in _TRAINERS:
+        available = ", ".join(list_trainers())
+        msg = f"Unknown trainer: {name}. Available trainers: {available}"
+        raise ValueError(msg)
+    return _TRAINERS[name]
+
+
+def list_trainers() -> list[str]:
+    """List available trainers.
+
+    Returns
+    -------
+    list[str]
+        List of registered trainer names.
+
+    Examples
+    --------
+    >>> trainers = list_trainers()
+    >>> "huggingface" in trainers
+    True
+    >>> "pytorch_lightning" in trainers
+    True
+    """
+    return list(_TRAINERS.keys())
+
+
+# Register built-in trainers
+register_trainer("huggingface", HuggingFaceTrainer)
+register_trainer("pytorch_lightning", PyTorchLightningTrainer)
bead/adapters/__init__.py
@@ -0,0 +1,11 @@
+"""Shared adapter utilities.
+
+Provides base classes and utilities for integrating with external ML
+frameworks like HuggingFace Transformers.
+"""
+
+from __future__ import annotations
+
+from bead.adapters.huggingface import HuggingFaceAdapterMixin
+
+__all__ = ["HuggingFaceAdapterMixin"]
bead/adapters/huggingface.py
@@ -0,0 +1,61 @@
+"""Shared utilities for HuggingFace Transformers adapters.
+
+This module provides common functionality for adapters that integrate with
+HuggingFace Transformers models, including device validation and shared
+utilities.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Literal
+
+import torch.backends.mps
+import torch.cuda
+
+logger = logging.getLogger(__name__)
+
+DeviceType = Literal["cpu", "cuda", "mps"]
+
+
+def _cuda_available() -> bool:
+    """Check if CUDA is available."""
+    return torch.cuda.is_available()  # pyright: ignore[reportAttributeAccessIssue]
+
+
+def _mps_available() -> bool:
+    """Check if MPS (Apple Silicon) is available."""
+    return torch.backends.mps.is_available()  # pyright: ignore[reportAttributeAccessIssue]
+
+
+class HuggingFaceAdapterMixin:
+    """Mixin providing common HuggingFace adapter functionality.
+
+    This mixin provides device validation with automatic fallback.
+
+    Attributes
+    ----------
+    device : DeviceType
+        The validated device (cpu, cuda, or mps).
+    """
+
+    def _validate_device(self, device: DeviceType) -> DeviceType:
+        """Validate device and fall back if unavailable.
+
+        Parameters
+        ----------
+        device : DeviceType
+            Requested device.
+
+        Returns
+        -------
+        DeviceType
+            Validated device (falls back to CPU if unavailable).
+        """
+        if device == "cuda" and not _cuda_available():
+            logger.warning("CUDA unavailable, using CPU")
+            return "cpu"
+        if device == "mps" and not _mps_available():
+            logger.warning("MPS unavailable, using CPU")
+            return "cpu"
+        return device
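A sketch of how a concrete adapter might use the mixin; `TextAdapter` is hypothetical, while `HuggingFaceAdapterMixin`, `DeviceType`, and `_validate_device` come from the hunk above:

```python
from bead.adapters.huggingface import DeviceType, HuggingFaceAdapterMixin


class TextAdapter(HuggingFaceAdapterMixin):
    """Hypothetical adapter that resolves its device at construction time."""

    def __init__(self, device: DeviceType = "cuda") -> None:
        # _validate_device falls back to "cpu" (with a logged warning)
        # when the requested CUDA or MPS backend is unavailable.
        self.device = self._validate_device(device)


adapter = TextAdapter(device="cuda")
print(adapter.device)  # "cuda" on a CUDA machine, otherwise "cpu"
```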