bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/config/active_learning.py
@@ -0,0 +1,1009 @@
+ """Active learning configuration models for the bead package."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Literal
+
+ from pydantic import BaseModel, Field, model_validator
+
+ from bead.active_learning.config import MixedEffectsConfig
+ from bead.data.range import Range
+
+
+ class BaseEncoderModelConfig(BaseModel):
+     """Base configuration for encoder-based active learning models.
+
+     Provides shared configuration fields for models that use transformer
+     encoders with optional dual-encoder architecture and mixed effects.
+
+     Parameters
+     ----------
+     model_name : str
+         HuggingFace model identifier.
+     max_length : int
+         Maximum sequence length for tokenization.
+     encoder_mode : Literal["single_encoder", "dual_encoder"]
+         Encoding strategy for input processing.
+     include_instructions : bool
+         Whether to include task instructions.
+     learning_rate : float
+         Learning rate for AdamW optimizer.
+     batch_size : int
+         Batch size for training.
+     num_epochs : int
+         Number of training epochs.
+     device : Literal["cpu", "cuda", "mps"]
+         Device to train on.
+     mixed_effects : MixedEffectsConfig
+         Mixed effects configuration for participant-level modeling.
+
+     Examples
+     --------
+     >>> config = BaseEncoderModelConfig()
+     >>> config.model_name
+     'bert-base-uncased'
+     >>> config.batch_size
+     16
+     >>> config.mixed_effects.mode
+     'fixed'
+     """
+
+     model_name: str = Field(
+         default="bert-base-uncased",
+         description="HuggingFace model identifier",
+     )
+     max_length: int = Field(
+         default=128,
+         description="Maximum sequence length for tokenization",
+         gt=0,
+     )
+     encoder_mode: Literal["single_encoder", "dual_encoder"] = Field(
+         default="single_encoder",
+         description="Encoding strategy for input processing",
+     )
+     include_instructions: bool = Field(
+         default=False,
+         description="Whether to include task instructions",
+     )
+     learning_rate: float = Field(
+         default=2e-5,
+         description="Learning rate for AdamW optimizer",
+         gt=0,
+     )
+     batch_size: int = Field(
+         default=16,
+         description="Batch size for training",
+         gt=0,
+     )
+     num_epochs: int = Field(
+         default=3,
+         description="Number of training epochs",
+         gt=0,
+     )
+     device: Literal["cpu", "cuda", "mps"] = Field(
+         default="cpu",
+         description="Device to train on",
+     )
+     mixed_effects: MixedEffectsConfig = Field(
+         default_factory=MixedEffectsConfig,
+         description="Mixed effects configuration for participant-level modeling",
+     )
+
+
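Because these are pydantic models, the Field constraints declared above (gt=0 on max_length, learning_rate, batch_size, and num_epochs) are enforced at construction time. A minimal sketch of that behavior, assuming only the class as it appears in this diff:

```python
from pydantic import ValidationError

# Defaults come straight from the Field(...) declarations.
config = BaseEncoderModelConfig()
assert config.model_name == "bert-base-uncased"
assert config.learning_rate == 2e-5

# Constraint violations are rejected up front rather than failing mid-training.
try:
    BaseEncoderModelConfig(batch_size=0)  # violates gt=0
except ValidationError as err:
    print(err.error_count(), "validation error")
```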
+ class ForcedChoiceModelConfig(BaseEncoderModelConfig):
+     """Configuration for forced choice active learning models.
+
+     Inherits all fields from BaseEncoderModelConfig. Used for tasks where
+     participants select one option from a set of alternatives.
+
+     Parameters
+     ----------
+     model_name : str
+         HuggingFace model identifier.
+     max_length : int
+         Maximum sequence length for tokenization.
+     encoder_mode : Literal["single_encoder", "dual_encoder"]
+         Encoding strategy for options.
+     include_instructions : bool
+         Whether to include task instructions.
+     learning_rate : float
+         Learning rate for AdamW optimizer.
+     batch_size : int
+         Batch size for training.
+     num_epochs : int
+         Number of training epochs.
+     device : Literal["cpu", "cuda", "mps"]
+         Device to train on.
+     mixed_effects : MixedEffectsConfig
+         Mixed effects configuration for participant-level modeling.
+
+     Examples
+     --------
+     >>> config = ForcedChoiceModelConfig()
+     >>> config.model_name
+     'bert-base-uncased'
+     >>> config.batch_size
+     16
+     >>> config.mixed_effects.mode
+     'fixed'
+     """
+
+
+ class UncertaintySamplerConfig(BaseModel):
+     """Configuration for uncertainty sampling strategies.
+
+     Parameters
+     ----------
+     method : str
+         Uncertainty method to use ("entropy", "margin", "least_confidence").
+     batch_size : int | None
+         Number of items to select per iteration. If None, uses the
+         budget_per_iteration from ActiveLearningLoopConfig.
+
+     Examples
+     --------
+     >>> config = UncertaintySamplerConfig()
+     >>> config.method
+     'entropy'
+     >>> config = UncertaintySamplerConfig(method="margin", batch_size=50)
+     >>> config.method
+     'margin'
+     """
+
+     method: Literal["entropy", "margin", "least_confidence"] = Field(
+         default="entropy",
+         description="Uncertainty sampling method",
+     )
+     batch_size: int | None = Field(
+         default=None,
+         description="Number of items to select per iteration",
+         gt=0,
+     )
+
+
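The three method values name standard uncertainty measures from the active learning literature. The package's own scoring code lives in bead/active_learning/strategies.py and selection.py (not shown in this hunk); as a sketch, the conventional definitions are:

```python
import numpy as np

def uncertainty(probs: np.ndarray, method: str = "entropy") -> np.ndarray:
    """Score items by predictive uncertainty; probs has shape (n_items, n_classes)."""
    if method == "entropy":
        # Shannon entropy of the predictive distribution.
        return -np.sum(probs * np.log(probs + 1e-12), axis=1)
    if method == "margin":
        # Small gap between the top two classes means high uncertainty, so negate it.
        top2 = np.sort(probs, axis=1)[:, -2:]
        return -(top2[:, 1] - top2[:, 0])
    if method == "least_confidence":
        return 1.0 - probs.max(axis=1)
    raise ValueError(f"unknown method: {method}")
```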
+ class JatosDataCollectionConfig(BaseModel):
+     """Configuration for JATOS data collection.
+
+     Parameters
+     ----------
+     base_url : str
+         JATOS base URL (e.g., "https://jatos.example.com").
+     api_token : str
+         JATOS API token for authentication.
+     study_id : int
+         JATOS study ID to collect data from.
+
+     Examples
+     --------
+     >>> config = JatosDataCollectionConfig(
+     ...     base_url="https://jatos.example.com",
+     ...     api_token="secret-token",
+     ...     study_id=123,
+     ... )
+     >>> config.base_url
+     'https://jatos.example.com'
+     """
+
+     base_url: str = Field(..., description="JATOS base URL")
+     api_token: str = Field(..., description="JATOS API token")
+     study_id: int = Field(..., description="JATOS study ID")
+
+
+ class ProlificDataCollectionConfig(BaseModel):
+     """Configuration for Prolific data collection.
+
+     Parameters
+     ----------
+     api_key : str
+         Prolific API key for authentication.
+     study_id : str
+         Prolific study ID to collect data from.
+
+     Examples
+     --------
+     >>> config = ProlificDataCollectionConfig(
+     ...     api_key="secret-key",
+     ...     study_id="abc123",
+     ... )
+     >>> config.study_id
+     'abc123'
+     """
+
+     api_key: str = Field(..., description="Prolific API key")
+     study_id: str = Field(..., description="Prolific study ID")
+
+
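Both data collection configs declare their credentials as required fields (Field(...)), so they are natural candidates for injection from the environment rather than hard-coded literals. A sketch, with hypothetical variable names (the package's own environment handling lives in bead/config/env.py and may differ):

```python
import os

# Hypothetical environment variable names; adjust to your deployment.
jatos = JatosDataCollectionConfig(
    base_url=os.environ["JATOS_BASE_URL"],
    api_token=os.environ["JATOS_API_TOKEN"],
    study_id=int(os.environ["JATOS_STUDY_ID"]),
)
prolific = ProlificDataCollectionConfig(
    api_key=os.environ["PROLIFIC_API_KEY"],
    study_id=os.environ["PROLIFIC_STUDY_ID"],
)
```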
+ class ActiveLearningLoopConfig(BaseModel):
+     """Configuration for active learning loop orchestration.
+
+     Parameters
+     ----------
+     max_iterations : int
+         Maximum number of AL iterations to run.
+     budget_per_iteration : int
+         Number of items to select per iteration.
+     stopping_criterion : str
+         Stopping criterion.
+     performance_threshold : float | None
+         Performance threshold for stopping.
+     metric_name : str
+         Metric name for convergence/threshold checks.
+     convergence_patience : int
+         Iterations to wait before declaring convergence.
+     convergence_threshold : float
+         Minimum improvement to avoid convergence.
+     jatos : JatosDataCollectionConfig | None
+         Configuration for JATOS data collection. If None, JATOS integration
+         is disabled.
+     prolific : ProlificDataCollectionConfig | None
+         Configuration for Prolific data collection. If None, Prolific
+         integration is disabled.
+     data_collection_timeout : int
+         Timeout in seconds for data collection.
+
+     Examples
+     --------
+     >>> config = ActiveLearningLoopConfig()
+     >>> config.max_iterations
+     10
+     >>> config.budget_per_iteration
+     100
+
+     >>> # With JATOS integration
+     >>> jatos_config = JatosDataCollectionConfig(
+     ...     base_url="https://jatos.example.com",
+     ...     api_token="secret-token",
+     ...     study_id=123,
+     ... )
+     >>> config = ActiveLearningLoopConfig(jatos=jatos_config)
+     >>> config.jatos.study_id
+     123
+     """
+
+     max_iterations: int = Field(
+         default=10,
+         description="Maximum number of iterations",
+         gt=0,
+     )
+     budget_per_iteration: int = Field(
+         default=100,
+         description="Number of items to select per iteration",
+         gt=0,
+     )
+     stopping_criterion: Literal[
+         "max_iterations", "convergence", "performance_threshold"
+     ] = Field(
+         default="max_iterations",
+         description="Stopping criterion for the loop",
+     )
+     performance_threshold: float | None = Field(
+         default=None,
+         description="Performance threshold for stopping",
+         ge=0,
+         le=1,
+     )
+     metric_name: str = Field(
+         default="accuracy",
+         description="Metric name for convergence/threshold checks",
+     )
+     convergence_patience: int = Field(
+         default=3,
+         description="Iterations to wait before declaring convergence",
+         gt=0,
+     )
+     convergence_threshold: float = Field(
+         default=0.01,
+         description="Minimum improvement to avoid convergence",
+         gt=0,
+     )
+     # data collection configuration (optional)
+     jatos: JatosDataCollectionConfig | None = Field(
+         default=None,
+         description="Configuration for JATOS data collection",
+     )
+     prolific: ProlificDataCollectionConfig | None = Field(
+         default=None,
+         description="Configuration for Prolific data collection",
+     )
+     data_collection_timeout: int = Field(
+         default=3600,
+         description="Timeout in seconds for data collection",
+         gt=0,
+     )
+
+
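The real orchestration logic lives in bead/active_learning/loop.py (+566 lines, not included in this hunk). A sketch of the conventional semantics these fields suggest, offered as an assumption rather than the package's actual implementation:

```python
def should_stop(
    config: ActiveLearningLoopConfig,
    iteration: int,
    metric_history: list[float],
) -> bool:
    """Hypothetical stopping check for the three criteria above."""
    if iteration >= config.max_iterations:
        return True  # hard cap applies regardless of criterion
    if config.stopping_criterion == "performance_threshold":
        return (
            config.performance_threshold is not None
            and bool(metric_history)
            and metric_history[-1] >= config.performance_threshold
        )
    if config.stopping_criterion == "convergence":
        window = metric_history[-(config.convergence_patience + 1):]
        if len(window) <= config.convergence_patience:
            return False
        # Converged if no step in the window improved by at least the threshold.
        return all(
            later - earlier < config.convergence_threshold
            for earlier, later in zip(window, window[1:])
        )
    return False  # "max_iterations": only the cap above stops the loop
```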
+ class TrainerConfig(BaseModel):
+     """Configuration for active learning trainers (HuggingFace, Lightning, etc.).
+
+     Parameters
+     ----------
+     trainer_type : str
+         Trainer type ("huggingface", "lightning").
+     epochs : int
+         Number of training epochs.
+     eval_strategy : str
+         Evaluation strategy.
+     save_strategy : str
+         Save strategy.
+     logging_dir : Path
+         Logging directory.
+     use_wandb : bool
+         Whether to use Weights & Biases.
+     wandb_project : str | None
+         W&B project name.
+
+     Examples
+     --------
+     >>> config = TrainerConfig()
+     >>> config.trainer_type
+     'huggingface'
+     >>> config.epochs
+     3
+     """
+
+     trainer_type: Literal["huggingface", "lightning"] = Field(
+         default="huggingface",
+         description="Trainer type",
+     )
+     epochs: int = Field(default=3, description="Training epochs", gt=0)
+     eval_strategy: str = Field(default="epoch", description="Evaluation strategy")
+     save_strategy: str = Field(default="epoch", description="Save strategy")
+     logging_dir: Path = Field(default=Path("logs"), description="Logging directory")
+     use_wandb: bool = Field(default=False, description="Use Weights & Biases")
+     wandb_project: str | None = Field(default=None, description="W&B project name")
+
+
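The "epoch" defaults for eval_strategy and save_strategy mirror the vocabulary of HuggingFace TrainingArguments. A hypothetical instantiation for a Lightning run with W&B logging:

```python
# Hypothetical values; the wandb_project name is illustrative only.
trainer_config = TrainerConfig(
    trainer_type="lightning",
    epochs=5,
    use_wandb=True,
    wandb_project="bead-pilot",
)
```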
+ class CategoricalModelConfig(BaseEncoderModelConfig):
+     """Configuration for categorical active learning models.
+
+     Inherits all fields from BaseEncoderModelConfig. Used for tasks where
+     participants select one category from a predefined set.
+
+     Parameters
+     ----------
+     model_name : str
+         HuggingFace model identifier.
+     max_length : int
+         Maximum sequence length for tokenization.
+     encoder_mode : Literal["single_encoder", "dual_encoder"]
+         Encoding strategy for categories.
+     include_instructions : bool
+         Whether to include task instructions.
+     learning_rate : float
+         Learning rate for AdamW optimizer.
+     batch_size : int
+         Batch size for training.
+     num_epochs : int
+         Number of training epochs.
+     device : Literal["cpu", "cuda", "mps"]
+         Device to train on.
+     mixed_effects : MixedEffectsConfig
+         Mixed effects configuration for participant-level modeling.
+
+     Examples
+     --------
+     >>> config = CategoricalModelConfig()
+     >>> config.model_name
+     'bert-base-uncased'
+     >>> config.mixed_effects.mode
+     'fixed'
+     """
+
+
+ class BinaryModelConfig(BaseEncoderModelConfig):
+     """Configuration for binary active learning models.
+
+     Inherits all fields from BaseEncoderModelConfig. Used for binary
+     classification tasks (yes/no, true/false, acceptable/unacceptable).
+
+     Parameters
+     ----------
+     model_name : str
+         HuggingFace model identifier.
+     max_length : int
+         Maximum sequence length for tokenization.
+     encoder_mode : Literal["single_encoder", "dual_encoder"]
+         Encoding strategy for binary classification.
+     include_instructions : bool
+         Whether to include task instructions.
+     learning_rate : float
+         Learning rate for AdamW optimizer.
+     batch_size : int
+         Batch size for training.
+     num_epochs : int
+         Number of training epochs.
+     device : Literal["cpu", "cuda", "mps"]
+         Device to train on.
+     mixed_effects : MixedEffectsConfig
+         Mixed effects configuration for participant-level modeling.
+
+     Examples
+     --------
+     >>> config = BinaryModelConfig()
+     >>> config.model_name
+     'bert-base-uncased'
+     >>> config.mixed_effects.mode
+     'fixed'
+     """
+
+
+ class MultiSelectModelConfig(BaseEncoderModelConfig):
+     """Configuration for multi-select active learning models.
+
+     Inherits all fields from BaseEncoderModelConfig. Used for tasks where
+     participants can select multiple options from a set of alternatives.
+
+     Parameters
+     ----------
+     model_name : str
+         HuggingFace model identifier.
+     max_length : int
+         Maximum sequence length for tokenization.
+     encoder_mode : Literal["single_encoder", "dual_encoder"]
+         Encoding strategy for multi-select options.
+     include_instructions : bool
+         Whether to include task instructions.
+     learning_rate : float
+         Learning rate for AdamW optimizer.
+     batch_size : int
+         Batch size for training.
+     num_epochs : int
+         Number of training epochs.
+     device : Literal["cpu", "cuda", "mps"]
+         Device to train on.
+     mixed_effects : MixedEffectsConfig
+         Mixed effects configuration for participant-level modeling.
+
+     Examples
+     --------
+     >>> config = MultiSelectModelConfig()
+     >>> config.model_name
+     'bert-base-uncased'
+     >>> config.mixed_effects.mode
+     'fixed'
+     """
+
+
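Like the other configs in this module, the model configs are plain pydantic models, so they round-trip cleanly through JSON; a sketch:

```python
import json

config = MultiSelectModelConfig(batch_size=32, device="cuda")

# Round-trip through JSON: handy for checking configs into a study repository.
payload = config.model_dump_json()
restored = MultiSelectModelConfig.model_validate_json(payload)
assert restored == config
assert json.loads(payload)["batch_size"] == 32
```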
+ class OrdinalScaleModelConfig(BaseModel):
+     """Configuration for ordinal scale active learning models.
+
+     Parameters
+     ----------
+     model_name : str
+         HuggingFace model identifier.
+     max_length : int
+         Maximum sequence length for tokenization.
+     encoder_mode : Literal["single_encoder"]
+         Encoding strategy for ordinal scale tasks.
+     include_instructions : bool
+         Whether to include task instructions.
+     learning_rate : float
+         Learning rate for AdamW optimizer.
+     batch_size : int
+         Batch size for training.
+     num_epochs : int
+         Number of training epochs.
+     device : Literal["cpu", "cuda", "mps"]
+         Device to train on.
+     scale : Range[float]
+         Numeric range for the ordinal scale (default: 0.0 to 1.0).
+     distribution : Literal["truncated_normal"]
+         Distribution for modeling bounded continuous responses.
+     sigma : float
+         Standard deviation for truncated normal distribution.
+     mixed_effects : MixedEffectsConfig
+         Mixed effects configuration for participant-level modeling.
+
+     Examples
+     --------
+     >>> config = OrdinalScaleModelConfig()
+     >>> config.model_name
+     'bert-base-uncased'
+     >>> config.scale.min
+     0.0
+     >>> config.scale.max
+     1.0
+     >>> config.mixed_effects.mode
+     'fixed'
+
+     >>> # Custom scale from 1.0 to 5.0
+     >>> config = OrdinalScaleModelConfig(
+     ...     scale=Range[float](min=1.0, max=5.0)
+     ... )
+     >>> config.scale.contains(3.5)
+     True
+     """
+
+     model_name: str = Field(
+         default="bert-base-uncased",
+         description="HuggingFace model identifier",
+     )
+     max_length: int = Field(
+         default=128,
+         description="Maximum sequence length for tokenization",
+         gt=0,
+     )
+     encoder_mode: Literal["single_encoder"] = Field(
+         default="single_encoder",
+         description="Encoding strategy for ordinal scale tasks",
+     )
+     include_instructions: bool = Field(
+         default=False,
+         description="Whether to include task instructions",
+     )
+     learning_rate: float = Field(
+         default=2e-5,
+         description="Learning rate for AdamW optimizer",
+         gt=0,
+     )
+     batch_size: int = Field(
+         default=16,
+         description="Batch size for training",
+         gt=0,
+     )
+     num_epochs: int = Field(
+         default=3,
+         description="Number of training epochs",
+         gt=0,
+     )
+     device: Literal["cpu", "cuda", "mps"] = Field(
+         default="cpu",
+         description="Device to train on",
+     )
+     scale: Range[float] = Field(
+         default_factory=lambda: Range[float](min=0.0, max=1.0),
+         description="Numeric range for the ordinal scale",
+     )
+     distribution: Literal["truncated_normal"] = Field(
+         default="truncated_normal",
+         description="Distribution for modeling bounded continuous responses",
+     )
+     sigma: float = Field(
+         default=0.1,
+         description="Standard deviation for truncated normal distribution",
+         gt=0,
+     )
+     mixed_effects: MixedEffectsConfig = Field(
+         default_factory=MixedEffectsConfig,
+         description="Mixed effects configuration for participant-level modeling",
+     )
+
+
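Only min, max, and contains() of the generic Range type (bead/data/range.py) are evidenced in this diff. A hypothetical helper that maps a raw 1-7 Likert response linearly onto the configured scale, assuming just those attributes:

```python
def normalize_to_scale(raw: float, raw_min: float, raw_max: float,
                       config: OrdinalScaleModelConfig) -> float:
    """Map a raw ordinal response linearly onto config.scale (hypothetical helper)."""
    fraction = (raw - raw_min) / (raw_max - raw_min)
    value = config.scale.min + fraction * (config.scale.max - config.scale.min)
    assert config.scale.contains(value)
    return value

config = OrdinalScaleModelConfig()          # scale defaults to [0.0, 1.0]
print(normalize_to_scale(4, 1, 7, config))  # 0.5
```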
+ class MagnitudeModelConfig(BaseModel):
+     """Configuration for magnitude active learning models.
+
+     Parameters
+     ----------
+     model_name : str
+         HuggingFace model identifier.
+     max_length : int
+         Maximum sequence length for tokenization.
+     encoder_mode : Literal["single_encoder"]
+         Encoding strategy for magnitude tasks.
+     include_instructions : bool
+         Whether to include task instructions.
+     learning_rate : float
+         Learning rate for AdamW optimizer.
+     batch_size : int
+         Batch size for training.
+     num_epochs : int
+         Number of training epochs.
+     device : Literal["cpu", "cuda", "mps"]
+         Device to train on.
+     bounded : bool
+         Whether magnitude values are bounded to a range.
+     min_value : float | None
+         Minimum value (for bounded case). Required if bounded=True.
+     max_value : float | None
+         Maximum value (for bounded case). Required if bounded=True.
+     distribution : Literal["normal", "truncated_normal"]
+         Distribution for modeling responses.
+         "normal" for unbounded, "truncated_normal" for bounded.
+     sigma : float
+         Standard deviation for the distribution.
+     mixed_effects : MixedEffectsConfig
+         Mixed effects configuration for participant-level modeling.
+
+     Examples
+     --------
+     >>> # Unbounded magnitude (e.g., reading time)
+     >>> config = MagnitudeModelConfig(bounded=False, distribution="normal")
+     >>> config.bounded
+     False
+     >>> config.distribution
+     'normal'
+
+     >>> # Bounded magnitude (e.g., confidence on 0-100 scale)
+     >>> config = MagnitudeModelConfig(
+     ...     bounded=True,
+     ...     min_value=0.0,
+     ...     max_value=100.0,
+     ...     distribution="truncated_normal"
+     ... )
+     >>> config.min_value
+     0.0
+     """
+
+     model_name: str = Field(
+         default="bert-base-uncased",
+         description="HuggingFace model identifier",
+     )
+     max_length: int = Field(
+         default=128,
+         description="Maximum sequence length for tokenization",
+         gt=0,
+     )
+     encoder_mode: Literal["single_encoder"] = Field(
+         default="single_encoder",
+         description="Encoding strategy for magnitude tasks",
+     )
+     include_instructions: bool = Field(
+         default=False,
+         description="Whether to include task instructions",
+     )
+     learning_rate: float = Field(
+         default=2e-5,
+         description="Learning rate for AdamW optimizer",
+         gt=0,
+     )
+     batch_size: int = Field(
+         default=16,
+         description="Batch size for training",
+         gt=0,
+     )
+     num_epochs: int = Field(
+         default=3,
+         description="Number of training epochs",
+         gt=0,
+     )
+     device: Literal["cpu", "cuda", "mps"] = Field(
+         default="cpu",
+         description="Device to train on",
+     )
+     bounded: bool = Field(
+         default=False,
+         description="Whether magnitude values are bounded to a range",
+     )
+     min_value: float | None = Field(
+         default=None,
+         description="Minimum value (required if bounded=True)",
+     )
+     max_value: float | None = Field(
+         default=None,
+         description="Maximum value (required if bounded=True)",
+     )
+     distribution: Literal["normal", "truncated_normal"] = Field(
+         default="normal",
+         description="Distribution for modeling responses",
+     )
+     sigma: float = Field(
+         default=0.1,
+         description="Standard deviation for the distribution",
+         gt=0,
+     )
+     mixed_effects: MixedEffectsConfig = Field(
+         default_factory=MixedEffectsConfig,
+         description="Mixed effects configuration for participant-level modeling",
+     )
+
+     @model_validator(mode="after")
+     def validate_bounded_configuration(self) -> MagnitudeModelConfig:
+         """Validate bounded configuration consistency.
+
+         Raises
+         ------
+         ValueError
+             If bounded=True but min_value or max_value not set.
+         ValueError
+             If bounded=False but min_value or max_value is set.
+         ValueError
+             If min_value >= max_value.
+         ValueError
+             If distribution inconsistent with bounded setting.
+         """
+         if self.bounded:
+             if self.min_value is None or self.max_value is None:
+                 raise ValueError(
+                     "bounded=True requires both min_value and max_value to be set. "
+                     f"Got min_value={self.min_value}, max_value={self.max_value}."
+                 )
+             if self.min_value >= self.max_value:
+                 raise ValueError(
+                     f"min_value ({self.min_value}) must be less than "
+                     f"max_value ({self.max_value})."
+                 )
+             if self.distribution != "truncated_normal":
+                 raise ValueError(
+                     "bounded=True requires distribution='truncated_normal'. "
+                     f"Got distribution='{self.distribution}'."
+                 )
+         else:
+             if self.min_value is not None or self.max_value is not None:
+                 raise ValueError(
+                     "bounded=False but min_value or max_value is set. "
+                     f"Got min_value={self.min_value}, max_value={self.max_value}. "
+                     "Either set bounded=True or remove min_value/max_value."
+                 )
+             if self.distribution != "normal":
+                 raise ValueError(
+                     "bounded=False requires distribution='normal'. "
+                     f"Got distribution='{self.distribution}'."
+                 )
+         return self
+
+
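The model_validator makes the bounded/distribution pairing explicit at construction time; under pydantic v2, the ValueError raised inside the validator surfaces as a ValidationError:

```python
from pydantic import ValidationError

# Valid: bounded magnitude needs explicit bounds and a truncated distribution.
ok = MagnitudeModelConfig(
    bounded=True, min_value=0.0, max_value=100.0, distribution="truncated_normal"
)

# Invalid: bounded=True with the default distribution="normal" is rejected.
try:
    MagnitudeModelConfig(bounded=True, min_value=0.0, max_value=100.0)
except ValidationError as err:
    print(err)
```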
+ class FreeTextModelConfig(BaseModel):
+     """Configuration for free text generation with GLMM support.
+
+     Implements seq2seq generation with participant-level random effects using
+     LoRA (Low-Rank Adaptation) for random slopes mode.
+
+     Parameters
+     ----------
+     model_name : str
+         HuggingFace seq2seq model identifier (e.g., "t5-base", "facebook/bart-base").
+     max_input_length : int
+         Maximum input sequence length for tokenization.
+     max_output_length : int
+         Maximum output sequence length for generation.
+     num_beams : int
+         Beam search width (1 = greedy decoding).
+     temperature : float
+         Sampling temperature for generation.
+     top_p : float
+         Nucleus sampling probability cutoff.
+     learning_rate : float
+         Learning rate for AdamW optimizer.
+     batch_size : int
+         Batch size for training (typically smaller for seq2seq due to memory).
+     num_epochs : int
+         Number of training epochs.
+     device : Literal["cpu", "cuda", "mps"]
+         Device to train on.
+     lora_rank : int
+         LoRA rank r for low-rank decomposition (typical: 4-16).
+     lora_alpha : float
+         LoRA scaling factor α (typically 2*rank).
+     lora_dropout : float
+         Dropout probability for LoRA layers.
+     lora_target_modules : list[str]
+         Attention modules to apply LoRA (e.g., ["q_proj", "v_proj"]).
+     eval_metric : Literal["exact_match", "token_accuracy", "bleu"]
+         Evaluation metric for generation quality.
+     mixed_effects : MixedEffectsConfig
+         Mixed effects configuration for participant-level modeling.
+
+     Examples
+     --------
+     >>> config = FreeTextModelConfig()
+     >>> config.model_name
+     't5-base'
+     >>> config.lora_rank
+     8
+     >>> config.mixed_effects.mode
+     'fixed'
+
+     >>> # With random slopes (LoRA)
+     >>> config = FreeTextModelConfig(
+     ...     mixed_effects=MixedEffectsConfig(mode="random_slopes"),
+     ...     lora_rank=8,
+     ...     lora_alpha=16.0
+     ... )
+     """
+
+     model_name: str = Field(
+         default="t5-base",
+         description="HuggingFace seq2seq model identifier",
+     )
+     max_input_length: int = Field(
+         default=128,
+         description="Maximum input sequence length",
+         gt=0,
+     )
+     max_output_length: int = Field(
+         default=64,
+         description="Maximum output sequence length",
+         gt=0,
+     )
+     num_beams: int = Field(
+         default=4,
+         description="Beam search width (1 = greedy)",
+         gt=0,
+     )
+     temperature: float = Field(
+         default=1.0,
+         description="Sampling temperature",
+         gt=0.0,
+     )
+     top_p: float = Field(
+         default=0.9,
+         description="Nucleus sampling probability cutoff",
+         ge=0.0,
+         le=1.0,
+     )
+     learning_rate: float = Field(
+         default=2e-5,
+         description="Learning rate for AdamW optimizer",
+         gt=0,
+     )
+     batch_size: int = Field(
+         default=8,
+         description="Batch size for training",
+         gt=0,
+     )
+     num_epochs: int = Field(
+         default=3,
+         description="Number of training epochs",
+         gt=0,
+     )
+     device: Literal["cpu", "cuda", "mps"] = Field(
+         default="cpu",
+         description="Device to train on",
+     )
+     lora_rank: int = Field(
+         default=8,
+         description="LoRA rank r for low-rank decomposition",
+         gt=0,
+     )
+     lora_alpha: float = Field(
+         default=16.0,
+         description="LoRA scaling factor α",
+         gt=0,
+     )
+     lora_dropout: float = Field(
+         default=0.1,
+         description="Dropout probability for LoRA layers",
+         ge=0.0,
+         lt=1.0,
+     )
+     lora_target_modules: list[str] = Field(
+         default=["q", "v"],
+         description="Attention modules to apply LoRA to",
+     )
+     eval_metric: Literal["exact_match", "token_accuracy", "bleu"] = Field(
+         default="exact_match",
+         description="Evaluation metric for generation quality",
+     )
+     mixed_effects: MixedEffectsConfig = Field(
+         default_factory=MixedEffectsConfig,
+         description="Mixed effects configuration for participant-level modeling",
+     )
+
+
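Note the default target modules ["q", "v"] match T5's attention projection names, while the "q_proj"/"v_proj" naming in the docstring fits BART- and LLaMA-style models. How bead/active_learning/models/lora.py and peft_adapter.py consume these fields is not shown in this hunk; a conventional mapping onto the peft library would look like:

```python
from peft import LoraConfig, TaskType

config = FreeTextModelConfig()  # t5-base defaults

# Hypothetical wiring: bead's own adapter code may differ.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=config.lora_rank,
    lora_alpha=config.lora_alpha,
    lora_dropout=config.lora_dropout,
    target_modules=config.lora_target_modules,
)
```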
+ class ClozeModelConfig(BaseModel):
+     """Configuration for cloze (fill-in-the-blank) models with GLMM support.
+
+     Implements masked language modeling with participant-level random effects for
+     predicting tokens at unfilled slots in partially-filled templates.
+
+     Parameters
+     ----------
+     model_name : str
+         HuggingFace masked LM model identifier.
+         Examples: "bert-base-uncased", "roberta-base".
+     max_length : int
+         Maximum sequence length for tokenization.
+     learning_rate : float
+         Learning rate for AdamW optimizer.
+     batch_size : int
+         Batch size for training.
+     num_epochs : int
+         Number of training epochs.
+     device : Literal["cpu", "cuda", "mps"]
+         Device to train on.
+     mask_token : str
+         Token used for masking (model-specific, e.g., "[MASK]" for BERT).
+     eval_metric : Literal["exact_match", "token_accuracy"]
+         Evaluation metric for masked token prediction.
+     mixed_effects : MixedEffectsConfig
+         Mixed effects configuration for participant-level modeling.
+
+     Examples
+     --------
+     >>> config = ClozeModelConfig()
+     >>> config.model_name
+     'bert-base-uncased'
+     >>> config.mask_token
+     '[MASK]'
+     >>> config.mixed_effects.mode
+     'fixed'
+
+     >>> # With random intercepts
+     >>> config = ClozeModelConfig(
+     ...     mixed_effects=MixedEffectsConfig(mode="random_intercepts"),
+     ...     num_epochs=5
+     ... )
+     """
+
+     model_name: str = Field(
+         default="bert-base-uncased",
+         description="HuggingFace masked LM model identifier",
+     )
+     max_length: int = Field(
+         default=128,
+         description="Maximum sequence length for tokenization",
+         gt=0,
+     )
+     learning_rate: float = Field(
+         default=2e-5,
+         description="Learning rate for AdamW optimizer",
+         gt=0,
+     )
+     batch_size: int = Field(
+         default=16,
+         description="Batch size for training",
+         gt=0,
+     )
+     num_epochs: int = Field(
+         default=3,
+         description="Number of training epochs",
+         gt=0,
+     )
+     device: Literal["cpu", "cuda", "mps"] = Field(
+         default="cpu",
+         description="Device to train on",
+     )
+     mask_token: str = Field(
+         default="[MASK]",
+         description="Token used for masking (model-specific)",
+     )
+     eval_metric: Literal["exact_match", "token_accuracy"] = Field(
+         default="exact_match",
+         description="Evaluation metric for masked token prediction",
+     )
+     mixed_effects: MixedEffectsConfig = Field(
+         default_factory=MixedEffectsConfig,
+         description="Mixed effects configuration for participant-level modeling",
+     )
+
+
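The "[MASK]" default is BERT-specific (RoBERTa, for example, uses "<mask>"), so mask_token must be kept in sync with model_name by the caller. A sketch that derives it from the tokenizer instead, assuming only the transformers API:

```python
from transformers import AutoTokenizer

# Hypothetical pattern: read the mask token off the checkpoint rather than
# hard-coding it, since e.g. roberta-base uses "<mask>", not "[MASK]".
base = ClozeModelConfig(model_name="roberta-base")
tokenizer = AutoTokenizer.from_pretrained(base.model_name)
config = base.model_copy(update={"mask_token": tokenizer.mask_token})
assert config.mask_token == "<mask>"
```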
+ class ActiveLearningConfig(BaseModel):
+     """Configuration for active learning infrastructure.
+
+     Reflects the bead/active_learning/ module structure:
+     - models: Active learning models (ForcedChoiceModel, etc.)
+     - trainers: Training infrastructure (HuggingFace, Lightning)
+     - loop: Active learning loop orchestration
+     - selection: Item selection strategies (uncertainty sampling, etc.)
+
+     Parameters
+     ----------
+     forced_choice_model : ForcedChoiceModelConfig
+         Configuration for forced choice models.
+     trainer : TrainerConfig
+         Configuration for trainers (HuggingFace, Lightning).
+     loop : ActiveLearningLoopConfig
+         Configuration for active learning loop.
+     uncertainty_sampler : UncertaintySamplerConfig
+         Configuration for uncertainty sampling strategies.
+
+     Examples
+     --------
+     >>> config = ActiveLearningConfig()
+     >>> config.forced_choice_model.model_name
+     'bert-base-uncased'
+     >>> config.trainer.trainer_type
+     'huggingface'
+     >>> config.loop.max_iterations
+     10
+     >>> config.uncertainty_sampler.method
+     'entropy'
+     """
+
+     forced_choice_model: ForcedChoiceModelConfig = Field(
+         default_factory=ForcedChoiceModelConfig,
+         description="Forced choice model configuration",
+     )
+     trainer: TrainerConfig = Field(
+         default_factory=TrainerConfig,
+         description="Trainer configuration",
+     )
+     loop: ActiveLearningLoopConfig = Field(
+         default_factory=ActiveLearningLoopConfig,
+         description="Active learning loop configuration",
+     )
+     uncertainty_sampler: UncertaintySamplerConfig = Field(
+         default_factory=UncertaintySamplerConfig,
+         description="Uncertainty sampler configuration",
+     )
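Because every field of the top-level config is itself a pydantic model with a default_factory, the whole tree can be built from plain nested dicts, e.g. as parsed from a JSON or YAML profile (the package ships bead/config/loader.py and profiles.py, not shown here). A sketch:

```python
config = ActiveLearningConfig.model_validate(
    {
        "forced_choice_model": {"device": "cuda", "batch_size": 32},
        "loop": {"stopping_criterion": "convergence", "convergence_patience": 5},
        "uncertainty_sampler": {"method": "margin"},
    }
)
assert config.trainer.trainer_type == "huggingface"  # untouched defaults remain
```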