bead-0.1.0-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/config/active_learning.py
@@ -0,0 +1,1009 @@
+"""Active learning configuration models for the bead package."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Literal
+
+from pydantic import BaseModel, Field, model_validator
+
+from bead.active_learning.config import MixedEffectsConfig
+from bead.data.range import Range
+
+
+class BaseEncoderModelConfig(BaseModel):
+    """Base configuration for encoder-based active learning models.
+
+    Provides shared configuration fields for models that use transformer
+    encoders with optional dual-encoder architecture and mixed effects.
+
+    Parameters
+    ----------
+    model_name : str
+        HuggingFace model identifier.
+    max_length : int
+        Maximum sequence length for tokenization.
+    encoder_mode : Literal["single_encoder", "dual_encoder"]
+        Encoding strategy for input processing.
+    include_instructions : bool
+        Whether to include task instructions.
+    learning_rate : float
+        Learning rate for AdamW optimizer.
+    batch_size : int
+        Batch size for training.
+    num_epochs : int
+        Number of training epochs.
+    device : Literal["cpu", "cuda", "mps"]
+        Device to train on.
+    mixed_effects : MixedEffectsConfig
+        Mixed effects configuration for participant-level modeling.
+
+    Examples
+    --------
+    >>> config = BaseEncoderModelConfig()
+    >>> config.model_name
+    'bert-base-uncased'
+    >>> config.batch_size
+    16
+    >>> config.mixed_effects.mode
+    'fixed'
+    """
+
+    model_name: str = Field(
+        default="bert-base-uncased",
+        description="HuggingFace model identifier",
+    )
+    max_length: int = Field(
+        default=128,
+        description="Maximum sequence length for tokenization",
+        gt=0,
+    )
+    encoder_mode: Literal["single_encoder", "dual_encoder"] = Field(
+        default="single_encoder",
+        description="Encoding strategy for input processing",
+    )
+    include_instructions: bool = Field(
+        default=False,
+        description="Whether to include task instructions",
+    )
+    learning_rate: float = Field(
+        default=2e-5,
+        description="Learning rate for AdamW optimizer",
+        gt=0,
+    )
+    batch_size: int = Field(
+        default=16,
+        description="Batch size for training",
+        gt=0,
+    )
+    num_epochs: int = Field(
+        default=3,
+        description="Number of training epochs",
+        gt=0,
+    )
+    device: Literal["cpu", "cuda", "mps"] = Field(
+        default="cpu",
+        description="Device to train on",
+    )
+    mixed_effects: MixedEffectsConfig = Field(
+        default_factory=MixedEffectsConfig,
+        description="Mixed effects configuration for participant-level modeling",
+    )
+
+
+class ForcedChoiceModelConfig(BaseEncoderModelConfig):
+    """Configuration for forced choice active learning models.
+
+    Inherits all fields from BaseEncoderModelConfig. Used for tasks where
+    participants select one option from a set of alternatives.
+
+    Parameters
+    ----------
+    model_name : str
+        HuggingFace model identifier.
+    max_length : int
+        Maximum sequence length for tokenization.
+    encoder_mode : Literal["single_encoder", "dual_encoder"]
+        Encoding strategy for options.
+    include_instructions : bool
+        Whether to include task instructions.
+    learning_rate : float
+        Learning rate for AdamW optimizer.
+    batch_size : int
+        Batch size for training.
+    num_epochs : int
+        Number of training epochs.
+    device : Literal["cpu", "cuda", "mps"]
+        Device to train on.
+    mixed_effects : MixedEffectsConfig
+        Mixed effects configuration for participant-level modeling.
+
+    Examples
+    --------
+    >>> config = ForcedChoiceModelConfig()
+    >>> config.model_name
+    'bert-base-uncased'
+    >>> config.batch_size
+    16
+    >>> config.mixed_effects.mode
+    'fixed'
+    """
+
+
+class UncertaintySamplerConfig(BaseModel):
+    """Configuration for uncertainty sampling strategies.
+
+    Parameters
+    ----------
+    method : str
+        Uncertainty method to use ("entropy", "margin", "least_confidence").
+    batch_size : int | None
+        Number of items to select per iteration. If None, uses the
+        budget_per_iteration from ActiveLearningLoopConfig.
+
+    Examples
+    --------
+    >>> config = UncertaintySamplerConfig()
+    >>> config.method
+    'entropy'
+    >>> config = UncertaintySamplerConfig(method="margin", batch_size=50)
+    >>> config.method
+    'margin'
+    """
+
+    method: Literal["entropy", "margin", "least_confidence"] = Field(
+        default="entropy",
+        description="Uncertainty sampling method",
+    )
+    batch_size: int | None = Field(
+        default=None,
+        description="Number of items to select per iteration",
+        gt=0,
+    )
+
+
+class JatosDataCollectionConfig(BaseModel):
+    """Configuration for JATOS data collection.
+
+    Parameters
+    ----------
+    base_url : str
+        JATOS base URL (e.g., "https://jatos.example.com").
+    api_token : str
+        JATOS API token for authentication.
+    study_id : int
+        JATOS study ID to collect data from.
+
+    Examples
+    --------
+    >>> config = JatosDataCollectionConfig(
+    ...     base_url="https://jatos.example.com",
+    ...     api_token="secret-token",
+    ...     study_id=123,
+    ... )
+    >>> config.base_url
+    'https://jatos.example.com'
+    """
+
+    base_url: str = Field(..., description="JATOS base URL")
+    api_token: str = Field(..., description="JATOS API token")
+    study_id: int = Field(..., description="JATOS study ID")
+
+
+class ProlificDataCollectionConfig(BaseModel):
+    """Configuration for Prolific data collection.
+
+    Parameters
+    ----------
+    api_key : str
+        Prolific API key for authentication.
+    study_id : str
+        Prolific study ID to collect data from.
+
+    Examples
+    --------
+    >>> config = ProlificDataCollectionConfig(
+    ...     api_key="secret-key",
+    ...     study_id="abc123",
+    ... )
+    >>> config.study_id
+    'abc123'
+    """
+
+    api_key: str = Field(..., description="Prolific API key")
+    study_id: str = Field(..., description="Prolific study ID")
+
+
+class ActiveLearningLoopConfig(BaseModel):
+    """Configuration for active learning loop orchestration.
+
+    Parameters
+    ----------
+    max_iterations : int
+        Maximum number of AL iterations to run.
+    budget_per_iteration : int
+        Number of items to select per iteration.
+    stopping_criterion : str
+        Stopping criterion.
+    performance_threshold : float | None
+        Performance threshold for stopping.
+    metric_name : str
+        Metric name for convergence/threshold checks.
+    convergence_patience : int
+        Iterations to wait before declaring convergence.
+    convergence_threshold : float
+        Minimum improvement to avoid convergence.
+    jatos : JatosDataCollectionConfig | None
+        Configuration for JATOS data collection. If None, JATOS integration
+        is disabled.
+    prolific : ProlificDataCollectionConfig | None
+        Configuration for Prolific data collection. If None, Prolific
+        integration is disabled.
+    data_collection_timeout : int
+        Timeout in seconds for data collection.
+
+    Examples
+    --------
+    >>> config = ActiveLearningLoopConfig()
+    >>> config.max_iterations
+    10
+    >>> config.budget_per_iteration
+    100
+
+    >>> # With JATOS integration
+    >>> jatos_config = JatosDataCollectionConfig(
+    ...     base_url="https://jatos.example.com",
+    ...     api_token="secret-token",
+    ...     study_id=123,
+    ... )
+    >>> config = ActiveLearningLoopConfig(jatos=jatos_config)
+    >>> config.jatos.study_id
+    123
+    """
+
+    max_iterations: int = Field(
+        default=10,
+        description="Maximum number of iterations",
+        gt=0,
+    )
+    budget_per_iteration: int = Field(
+        default=100,
+        description="Number of items to select per iteration",
+        gt=0,
+    )
+    stopping_criterion: Literal[
+        "max_iterations", "convergence", "performance_threshold"
+    ] = Field(
+        default="max_iterations",
+        description="Stopping criterion for the loop",
+    )
+    performance_threshold: float | None = Field(
+        default=None,
+        description="Performance threshold for stopping",
+        ge=0,
+        le=1,
+    )
+    metric_name: str = Field(
+        default="accuracy",
+        description="Metric name for convergence/threshold checks",
+    )
+    convergence_patience: int = Field(
+        default=3,
+        description="Iterations to wait before declaring convergence",
+        gt=0,
+    )
+    convergence_threshold: float = Field(
+        default=0.01,
+        description="Minimum improvement to avoid convergence",
+        gt=0,
+    )
+    # data collection configuration (optional)
+    jatos: JatosDataCollectionConfig | None = Field(
+        default=None,
+        description="Configuration for JATOS data collection",
+    )
+    prolific: ProlificDataCollectionConfig | None = Field(
+        default=None,
+        description="Configuration for Prolific data collection",
+    )
+    data_collection_timeout: int = Field(
+        default=3600,
+        description="Timeout in seconds for data collection",
+        gt=0,
+    )
+
+
+class TrainerConfig(BaseModel):
+    """Configuration for active learning trainers (HuggingFace, Lightning, etc.).
+
+    Parameters
+    ----------
+    trainer_type : str
+        Trainer type ("huggingface", "lightning").
+    epochs : int
+        Number of training epochs.
+    eval_strategy : str
+        Evaluation strategy.
+    save_strategy : str
+        Save strategy.
+    logging_dir : Path
+        Logging directory.
+    use_wandb : bool
+        Whether to use Weights & Biases.
+    wandb_project : str | None
+        W&B project name.
+
+    Examples
+    --------
+    >>> config = TrainerConfig()
+    >>> config.trainer_type
+    'huggingface'
+    >>> config.epochs
+    3
+    """
+
+    trainer_type: Literal["huggingface", "lightning"] = Field(
+        default="huggingface",
+        description="Trainer type",
+    )
+    epochs: int = Field(default=3, description="Training epochs", gt=0)
+    eval_strategy: str = Field(default="epoch", description="Evaluation strategy")
+    save_strategy: str = Field(default="epoch", description="Save strategy")
+    logging_dir: Path = Field(default=Path("logs"), description="Logging directory")
+    use_wandb: bool = Field(default=False, description="Use Weights & Biases")
+    wandb_project: str | None = Field(default=None, description="W&B project name")
+
+
+class CategoricalModelConfig(BaseEncoderModelConfig):
+    """Configuration for categorical active learning models.
+
+    Inherits all fields from BaseEncoderModelConfig. Used for tasks where
+    participants select one category from a predefined set.
+
+    Parameters
+    ----------
+    model_name : str
+        HuggingFace model identifier.
+    max_length : int
+        Maximum sequence length for tokenization.
+    encoder_mode : Literal["single_encoder", "dual_encoder"]
+        Encoding strategy for categories.
+    include_instructions : bool
+        Whether to include task instructions.
+    learning_rate : float
+        Learning rate for AdamW optimizer.
+    batch_size : int
+        Batch size for training.
+    num_epochs : int
+        Number of training epochs.
+    device : Literal["cpu", "cuda", "mps"]
+        Device to train on.
+    mixed_effects : MixedEffectsConfig
+        Mixed effects configuration for participant-level modeling.
+
+    Examples
+    --------
+    >>> config = CategoricalModelConfig()
+    >>> config.model_name
+    'bert-base-uncased'
+    >>> config.mixed_effects.mode
+    'fixed'
+    """
+
+
+class BinaryModelConfig(BaseEncoderModelConfig):
+    """Configuration for binary active learning models.
+
+    Inherits all fields from BaseEncoderModelConfig. Used for binary
+    classification tasks (yes/no, true/false, acceptable/unacceptable).
+
+    Parameters
+    ----------
+    model_name : str
+        HuggingFace model identifier.
+    max_length : int
+        Maximum sequence length for tokenization.
+    encoder_mode : Literal["single_encoder", "dual_encoder"]
+        Encoding strategy for binary classification.
+    include_instructions : bool
+        Whether to include task instructions.
+    learning_rate : float
+        Learning rate for AdamW optimizer.
+    batch_size : int
+        Batch size for training.
+    num_epochs : int
+        Number of training epochs.
+    device : Literal["cpu", "cuda", "mps"]
+        Device to train on.
+    mixed_effects : MixedEffectsConfig
+        Mixed effects configuration for participant-level modeling.
+
+    Examples
+    --------
+    >>> config = BinaryModelConfig()
+    >>> config.model_name
+    'bert-base-uncased'
+    >>> config.mixed_effects.mode
+    'fixed'
+    """
+
+
+class MultiSelectModelConfig(BaseEncoderModelConfig):
+    """Configuration for multi-select active learning models.
+
+    Inherits all fields from BaseEncoderModelConfig. Used for tasks where
+    participants can select multiple options from a set of alternatives.
+
+    Parameters
+    ----------
+    model_name : str
+        HuggingFace model identifier.
+    max_length : int
+        Maximum sequence length for tokenization.
+    encoder_mode : Literal["single_encoder", "dual_encoder"]
+        Encoding strategy for multi-select options.
+    include_instructions : bool
+        Whether to include task instructions.
+    learning_rate : float
+        Learning rate for AdamW optimizer.
+    batch_size : int
+        Batch size for training.
+    num_epochs : int
+        Number of training epochs.
+    device : Literal["cpu", "cuda", "mps"]
+        Device to train on.
+    mixed_effects : MixedEffectsConfig
+        Mixed effects configuration for participant-level modeling.
+
+    Examples
+    --------
+    >>> config = MultiSelectModelConfig()
+    >>> config.model_name
+    'bert-base-uncased'
+    >>> config.mixed_effects.mode
+    'fixed'
+    """
+
+
+class OrdinalScaleModelConfig(BaseModel):
+    """Configuration for ordinal scale active learning models.
+
+    Parameters
+    ----------
+    model_name : str
+        HuggingFace model identifier.
+    max_length : int
+        Maximum sequence length for tokenization.
+    encoder_mode : Literal["single_encoder"]
+        Encoding strategy for ordinal scale tasks.
+    include_instructions : bool
+        Whether to include task instructions.
+    learning_rate : float
+        Learning rate for AdamW optimizer.
+    batch_size : int
+        Batch size for training.
+    num_epochs : int
+        Number of training epochs.
+    device : Literal["cpu", "cuda", "mps"]
+        Device to train on.
+    scale : Range[float]
+        Numeric range for the ordinal scale (default: 0.0 to 1.0).
+    distribution : Literal["truncated_normal"]
+        Distribution for modeling bounded continuous responses.
+    sigma : float
+        Standard deviation for truncated normal distribution.
+    mixed_effects : MixedEffectsConfig
+        Mixed effects configuration for participant-level modeling.
+
+    Examples
+    --------
+    >>> config = OrdinalScaleModelConfig()
+    >>> config.model_name
+    'bert-base-uncased'
+    >>> config.scale.min
+    0.0
+    >>> config.scale.max
+    1.0
+    >>> config.mixed_effects.mode
+    'fixed'
+
+    >>> # Custom scale from 1.0 to 5.0
+    >>> config = OrdinalScaleModelConfig(
+    ...     scale=Range[float](min=1.0, max=5.0)
+    ... )
+    >>> config.scale.contains(3.5)
+    True
+    """
+
+    model_name: str = Field(
+        default="bert-base-uncased",
+        description="HuggingFace model identifier",
+    )
+    max_length: int = Field(
+        default=128,
+        description="Maximum sequence length for tokenization",
+        gt=0,
+    )
+    encoder_mode: Literal["single_encoder"] = Field(
+        default="single_encoder",
+        description="Encoding strategy for ordinal scale tasks",
+    )
+    include_instructions: bool = Field(
+        default=False,
+        description="Whether to include task instructions",
+    )
+    learning_rate: float = Field(
+        default=2e-5,
+        description="Learning rate for AdamW optimizer",
+        gt=0,
+    )
+    batch_size: int = Field(
+        default=16,
+        description="Batch size for training",
+        gt=0,
+    )
+    num_epochs: int = Field(
+        default=3,
+        description="Number of training epochs",
+        gt=0,
+    )
+    device: Literal["cpu", "cuda", "mps"] = Field(
+        default="cpu",
+        description="Device to train on",
+    )
+    scale: Range[float] = Field(
+        default_factory=lambda: Range[float](min=0.0, max=1.0),
+        description="Numeric range for the ordinal scale",
+    )
+    distribution: Literal["truncated_normal"] = Field(
+        default="truncated_normal",
+        description="Distribution for modeling bounded continuous responses",
+    )
+    sigma: float = Field(
+        default=0.1,
+        description="Standard deviation for truncated normal distribution",
+        gt=0,
+    )
+    mixed_effects: MixedEffectsConfig = Field(
+        default_factory=MixedEffectsConfig,
+        description="Mixed effects configuration for participant-level modeling",
+    )
+
+
+class MagnitudeModelConfig(BaseModel):
+    """Configuration for magnitude active learning models.
+
+    Parameters
+    ----------
+    model_name : str
+        HuggingFace model identifier.
+    max_length : int
+        Maximum sequence length for tokenization.
+    encoder_mode : Literal["single_encoder"]
+        Encoding strategy for magnitude tasks.
+    include_instructions : bool
+        Whether to include task instructions.
+    learning_rate : float
+        Learning rate for AdamW optimizer.
+    batch_size : int
+        Batch size for training.
+    num_epochs : int
+        Number of training epochs.
+    device : Literal["cpu", "cuda", "mps"]
+        Device to train on.
+    bounded : bool
+        Whether magnitude values are bounded to a range.
+    min_value : float | None
+        Minimum value (for bounded case). Required if bounded=True.
+    max_value : float | None
+        Maximum value (for bounded case). Required if bounded=True.
+    distribution : Literal["normal", "truncated_normal"]
+        Distribution for modeling responses.
+        "normal" for unbounded, "truncated_normal" for bounded.
+    sigma : float
+        Standard deviation for the distribution.
+    mixed_effects : MixedEffectsConfig
+        Mixed effects configuration for participant-level modeling.
+
+    Examples
+    --------
+    >>> # Unbounded magnitude (e.g., reading time)
+    >>> config = MagnitudeModelConfig(bounded=False, distribution="normal")
+    >>> config.bounded
+    False
+    >>> config.distribution
+    'normal'
+
+    >>> # Bounded magnitude (e.g., confidence on 0-100 scale)
+    >>> config = MagnitudeModelConfig(
+    ...     bounded=True,
+    ...     min_value=0.0,
+    ...     max_value=100.0,
+    ...     distribution="truncated_normal"
+    ... )
+    >>> config.min_value
+    0.0
+    """
+
+    model_name: str = Field(
+        default="bert-base-uncased",
+        description="HuggingFace model identifier",
+    )
+    max_length: int = Field(
+        default=128,
+        description="Maximum sequence length for tokenization",
+        gt=0,
+    )
+    encoder_mode: Literal["single_encoder"] = Field(
+        default="single_encoder",
+        description="Encoding strategy for magnitude tasks",
+    )
+    include_instructions: bool = Field(
+        default=False,
+        description="Whether to include task instructions",
+    )
+    learning_rate: float = Field(
+        default=2e-5,
+        description="Learning rate for AdamW optimizer",
+        gt=0,
+    )
+    batch_size: int = Field(
+        default=16,
+        description="Batch size for training",
+        gt=0,
+    )
+    num_epochs: int = Field(
+        default=3,
+        description="Number of training epochs",
+        gt=0,
+    )
+    device: Literal["cpu", "cuda", "mps"] = Field(
+        default="cpu",
+        description="Device to train on",
+    )
+    bounded: bool = Field(
+        default=False,
+        description="Whether magnitude values are bounded to a range",
+    )
+    min_value: float | None = Field(
+        default=None,
+        description="Minimum value (required if bounded=True)",
+    )
+    max_value: float | None = Field(
+        default=None,
+        description="Maximum value (required if bounded=True)",
+    )
+    distribution: Literal["normal", "truncated_normal"] = Field(
+        default="normal",
+        description="Distribution for modeling responses",
+    )
+    sigma: float = Field(
+        default=0.1,
+        description="Standard deviation for the distribution",
+        gt=0,
+    )
+    mixed_effects: MixedEffectsConfig = Field(
+        default_factory=MixedEffectsConfig,
+        description="Mixed effects configuration for participant-level modeling",
+    )
+
+    @model_validator(mode="after")
+    def validate_bounded_configuration(self) -> MagnitudeModelConfig:
+        """Validate bounded configuration consistency.
+
+        Raises
+        ------
+        ValueError
+            If bounded=True but min_value or max_value not set.
+        ValueError
+            If bounded=False but min_value or max_value is set.
+        ValueError
+            If min_value >= max_value.
+        ValueError
+            If distribution inconsistent with bounded setting.
+        """
+        if self.bounded:
+            if self.min_value is None or self.max_value is None:
+                raise ValueError(
+                    "bounded=True requires both min_value and max_value to be set. "
+                    f"Got min_value={self.min_value}, max_value={self.max_value}."
+                )
+            if self.min_value >= self.max_value:
+                raise ValueError(
+                    f"min_value ({self.min_value}) must be less than "
+                    f"max_value ({self.max_value})."
+                )
+            if self.distribution != "truncated_normal":
+                raise ValueError(
+                    "bounded=True requires distribution='truncated_normal'. "
+                    f"Got distribution='{self.distribution}'."
+                )
+        else:
+            if self.min_value is not None or self.max_value is not None:
+                raise ValueError(
+                    "bounded=False but min_value or max_value is set. "
+                    f"Got min_value={self.min_value}, max_value={self.max_value}. "
+                    "Either set bounded=True or remove min_value/max_value."
+                )
+            if self.distribution != "normal":
+                raise ValueError(
+                    "bounded=False requires distribution='normal'. "
+                    f"Got distribution='{self.distribution}'."
+                )
+        return self
+
+
+class FreeTextModelConfig(BaseModel):
+    """Configuration for free text generation with GLMM support.
+
+    Implements seq2seq generation with participant-level random effects using
+    LoRA (Low-Rank Adaptation) for random slopes mode.
+
+    Parameters
+    ----------
+    model_name : str
+        HuggingFace seq2seq model identifier (e.g., "t5-base", "facebook/bart-base").
+    max_input_length : int
+        Maximum input sequence length for tokenization.
+    max_output_length : int
+        Maximum output sequence length for generation.
+    num_beams : int
+        Beam search width (1 = greedy decoding).
+    temperature : float
+        Sampling temperature for generation.
+    top_p : float
+        Nucleus sampling probability cutoff.
+    learning_rate : float
+        Learning rate for AdamW optimizer.
+    batch_size : int
+        Batch size for training (typically smaller for seq2seq due to memory).
+    num_epochs : int
+        Number of training epochs.
+    device : Literal["cpu", "cuda", "mps"]
+        Device to train on.
+    lora_rank : int
+        LoRA rank r for low-rank decomposition (typical: 4-16).
+    lora_alpha : float
+        LoRA scaling factor α (typically 2*rank).
+    lora_dropout : float
+        Dropout probability for LoRA layers.
+    lora_target_modules : list[str]
+        Attention modules to apply LoRA to (e.g., ["q_proj", "v_proj"]).
+    eval_metric : Literal["exact_match", "token_accuracy", "bleu"]
+        Evaluation metric for generation quality.
+    mixed_effects : MixedEffectsConfig
+        Mixed effects configuration for participant-level modeling.
+
+    Examples
+    --------
+    >>> config = FreeTextModelConfig()
+    >>> config.model_name
+    't5-base'
+    >>> config.lora_rank
+    8
+    >>> config.mixed_effects.mode
+    'fixed'
+
+    >>> # With random slopes (LoRA)
+    >>> config = FreeTextModelConfig(
+    ...     mixed_effects=MixedEffectsConfig(mode="random_slopes"),
+    ...     lora_rank=8,
+    ...     lora_alpha=16.0
+    ... )
+    """
+
+    model_name: str = Field(
+        default="t5-base",
+        description="HuggingFace seq2seq model identifier",
+    )
+    max_input_length: int = Field(
+        default=128,
+        description="Maximum input sequence length",
+        gt=0,
+    )
+    max_output_length: int = Field(
+        default=64,
+        description="Maximum output sequence length",
+        gt=0,
+    )
+    num_beams: int = Field(
+        default=4,
+        description="Beam search width (1 = greedy)",
+        gt=0,
+    )
+    temperature: float = Field(
+        default=1.0,
+        description="Sampling temperature",
+        gt=0.0,
+    )
+    top_p: float = Field(
+        default=0.9,
+        description="Nucleus sampling probability cutoff",
+        ge=0.0,
+        le=1.0,
+    )
+    learning_rate: float = Field(
+        default=2e-5,
+        description="Learning rate for AdamW optimizer",
+        gt=0,
+    )
+    batch_size: int = Field(
+        default=8,
+        description="Batch size for training",
+        gt=0,
+    )
+    num_epochs: int = Field(
+        default=3,
+        description="Number of training epochs",
+        gt=0,
+    )
+    device: Literal["cpu", "cuda", "mps"] = Field(
+        default="cpu",
+        description="Device to train on",
+    )
+    lora_rank: int = Field(
+        default=8,
+        description="LoRA rank r for low-rank decomposition",
+        gt=0,
+    )
+    lora_alpha: float = Field(
+        default=16.0,
+        description="LoRA scaling factor α",
+        gt=0,
+    )
+    lora_dropout: float = Field(
+        default=0.1,
+        description="Dropout probability for LoRA layers",
+        ge=0.0,
+        lt=1.0,
+    )
+    lora_target_modules: list[str] = Field(
+        default=["q", "v"],
+        description="Attention modules to apply LoRA to",
+    )
+    eval_metric: Literal["exact_match", "token_accuracy", "bleu"] = Field(
+        default="exact_match",
+        description="Evaluation metric for generation quality",
+    )
+    mixed_effects: MixedEffectsConfig = Field(
+        default_factory=MixedEffectsConfig,
+        description="Mixed effects configuration for participant-level modeling",
+    )
+
+
+class ClozeModelConfig(BaseModel):
+    """Configuration for cloze (fill-in-the-blank) models with GLMM support.
+
+    Implements masked language modeling with participant-level random effects for
+    predicting tokens at unfilled slots in partially-filled templates.
+
+    Parameters
+    ----------
+    model_name : str
+        HuggingFace masked LM model identifier.
+        Examples: "bert-base-uncased", "roberta-base".
+    max_length : int
+        Maximum sequence length for tokenization.
+    learning_rate : float
+        Learning rate for AdamW optimizer.
+    batch_size : int
+        Batch size for training.
+    num_epochs : int
+        Number of training epochs.
+    device : Literal["cpu", "cuda", "mps"]
+        Device to train on.
+    mask_token : str
+        Token used for masking (model-specific, e.g., "[MASK]" for BERT).
+    eval_metric : Literal["exact_match", "token_accuracy"]
+        Evaluation metric for masked token prediction.
+    mixed_effects : MixedEffectsConfig
+        Mixed effects configuration for participant-level modeling.
+
+    Examples
+    --------
+    >>> config = ClozeModelConfig()
+    >>> config.model_name
+    'bert-base-uncased'
+    >>> config.mask_token
+    '[MASK]'
+    >>> config.mixed_effects.mode
+    'fixed'
+
+    >>> # With random intercepts
+    >>> config = ClozeModelConfig(
+    ...     mixed_effects=MixedEffectsConfig(mode="random_intercepts"),
+    ...     num_epochs=5
+    ... )
+    """
+
+    model_name: str = Field(
+        default="bert-base-uncased",
+        description="HuggingFace masked LM model identifier",
+    )
+    max_length: int = Field(
+        default=128,
+        description="Maximum sequence length for tokenization",
+        gt=0,
+    )
+    learning_rate: float = Field(
+        default=2e-5,
+        description="Learning rate for AdamW optimizer",
+        gt=0,
+    )
+    batch_size: int = Field(
+        default=16,
+        description="Batch size for training",
+        gt=0,
+    )
+    num_epochs: int = Field(
+        default=3,
+        description="Number of training epochs",
+        gt=0,
+    )
+    device: Literal["cpu", "cuda", "mps"] = Field(
+        default="cpu",
+        description="Device to train on",
+    )
+    mask_token: str = Field(
+        default="[MASK]",
+        description="Token used for masking (model-specific)",
+    )
+    eval_metric: Literal["exact_match", "token_accuracy"] = Field(
+        default="exact_match",
+        description="Evaluation metric for masked token prediction",
+    )
+    mixed_effects: MixedEffectsConfig = Field(
+        default_factory=MixedEffectsConfig,
+        description="Mixed effects configuration for participant-level modeling",
+    )
+
+
+class ActiveLearningConfig(BaseModel):
+    """Configuration for active learning infrastructure.
+
+    Reflects the bead/active_learning/ module structure:
+    - models: Active learning models (ForcedChoiceModel, etc.)
+    - trainers: Training infrastructure (HuggingFace, Lightning)
+    - loop: Active learning loop orchestration
+    - selection: Item selection strategies (uncertainty sampling, etc.)
+
+    Parameters
+    ----------
+    forced_choice_model : ForcedChoiceModelConfig
+        Configuration for forced choice models.
+    trainer : TrainerConfig
+        Configuration for trainers (HuggingFace, Lightning).
+    loop : ActiveLearningLoopConfig
+        Configuration for active learning loop.
+    uncertainty_sampler : UncertaintySamplerConfig
+        Configuration for uncertainty sampling strategies.
+
+    Examples
+    --------
+    >>> config = ActiveLearningConfig()
+    >>> config.forced_choice_model.model_name
+    'bert-base-uncased'
+    >>> config.trainer.trainer_type
+    'huggingface'
+    >>> config.loop.max_iterations
+    10
+    >>> config.uncertainty_sampler.method
+    'entropy'
+    """
+
+    forced_choice_model: ForcedChoiceModelConfig = Field(
+        default_factory=ForcedChoiceModelConfig,
+        description="Forced choice model configuration",
+    )
+    trainer: TrainerConfig = Field(
+        default_factory=TrainerConfig,
+        description="Trainer configuration",
+    )
+    loop: ActiveLearningLoopConfig = Field(
+        default_factory=ActiveLearningLoopConfig,
+        description="Active learning loop configuration",
+    )
+    uncertainty_sampler: UncertaintySamplerConfig = Field(
+        default_factory=UncertaintySamplerConfig,
+        description="Uncertainty sampler configuration",
+    )
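
The classes above compose into one nested configuration object rooted at ActiveLearningConfig. The following is a minimal usage sketch, not taken from the package's own documentation: it assumes bead 0.1.0 is installed, class and field names are copied verbatim from the diff, and the JATOS host, token, and study ID are placeholders.

# Sketch: composing the config classes from bead/config/active_learning.py.
from bead.config.active_learning import (
    ActiveLearningConfig,
    ActiveLearningLoopConfig,
    JatosDataCollectionConfig,
    MagnitudeModelConfig,
    UncertaintySamplerConfig,
)

# Unspecified sections fall back to their defaults (BERT encoder,
# HuggingFace trainer, entropy sampling).
config = ActiveLearningConfig(
    loop=ActiveLearningLoopConfig(
        max_iterations=20,
        budget_per_iteration=50,
        stopping_criterion="convergence",
        jatos=JatosDataCollectionConfig(
            base_url="https://jatos.example.com",  # placeholder host
            api_token="secret-token",              # placeholder credential
            study_id=123,                          # placeholder study
        ),
    ),
    uncertainty_sampler=UncertaintySamplerConfig(method="margin"),
)
assert config.forced_choice_model.model_name == "bert-base-uncased"

# MagnitudeModelConfig enforces bounded/distribution consistency at
# construction time via its model_validator. An inconsistent combination
# surfaces as a pydantic ValidationError, which subclasses ValueError.
MagnitudeModelConfig(
    bounded=True, min_value=0.0, max_value=100.0, distribution="truncated_normal"
)  # valid: bounds set and distribution is truncated_normal
try:
    # invalid: bounded=True with no bounds (and distribution left at "normal")
    MagnitudeModelConfig(bounded=True)
except ValueError as err:
    print(err)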