bead-0.1.0-py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/config/simulation.py
ADDED

@@ -0,0 +1,206 @@
"""Simulation configuration models for the bead package."""

from __future__ import annotations

from pathlib import Path
from typing import Literal

from pydantic import BaseModel, Field


class NoiseModelConfig(BaseModel):
    """Configuration for noise model in simulated judgments.

    Attributes
    ----------
    noise_type : Literal["temperature", "systematic", "random", "none"]
        Type of noise to apply.
    temperature : float
        Temperature for scaling (higher = more random). Default: 1.0.
    bias_strength : float
        Strength of systematic biases (0.0-1.0). Default: 0.0.
    bias_type : str | None
        Type of bias ("length", "frequency", "position"). Default: None.
    random_noise_stddev : float
        Standard deviation for random noise. Default: 0.0.

    Examples
    --------
    >>> # Temperature-scaled decisions (more random)
    >>> config = NoiseModelConfig(noise_type="temperature", temperature=2.0)
    >>>
    >>> # Systematic length bias (prefer shorter)
    >>> config = NoiseModelConfig(
    ...     noise_type="systematic",
    ...     bias_strength=0.3,
    ...     bias_type="length"
    ... )
    >>>
    >>> # Random noise injection
    >>> config = NoiseModelConfig(
    ...     noise_type="random",
    ...     random_noise_stddev=0.1
    ... )
    """

    noise_type: Literal["temperature", "systematic", "random", "none"] = Field(
        default="temperature",
        description="Type of noise model",
    )
    temperature: float = Field(
        default=1.0,
        ge=0.01,
        le=10.0,
        description="Temperature for scaling decisions",
    )
    bias_strength: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Strength of systematic biases",
    )
    bias_type: str | None = Field(
        default=None,
        description="Type of systematic bias",
    )
    random_noise_stddev: float = Field(
        default=0.0,
        ge=0.0,
        description="Standard deviation for random noise",
    )
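
The noise types configured here are implemented elsewhere in the package, under bead/simulation/noise_models/ (see the file listing above). As a rough, self-contained sketch of what the temperature field controls (illustrative only, not bead's implementation), a scored choice can be sampled from softmax(scores / T), where larger T flattens the distribution:

import math
import random


def temperature_scaled_choice(
    scores: list[float], temperature: float, rng: random.Random
) -> int:
    """Sample an option index from softmax(scores / temperature).

    Illustrative only: higher temperature flattens the distribution
    (more random choices); temperature near 0 approaches argmax.
    """
    logits = [s / temperature for s in scores]
    peak = max(logits)
    weights = [math.exp(x - peak) for x in logits]  # stable softmax numerators
    threshold = rng.random() * sum(weights)
    cumulative = 0.0
    for index, weight in enumerate(weights):
        cumulative += weight
        if threshold <= cumulative:
            return index
    return len(weights) - 1


rng = random.Random(42)
temperature_scaled_choice([2.0, 1.0], temperature=0.1, rng=rng)   # almost always 0
temperature_scaled_choice([2.0, 1.0], temperature=10.0, rng=rng)  # near a coin flip
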

class SimulatedAnnotatorConfig(BaseModel):
    """Configuration for simulated annotator.

    Attributes
    ----------
    strategy : Literal["lm_score", "distance", "random", "oracle", "dsl"]
        Base strategy for generating judgments.
    noise_model : NoiseModelConfig
        Noise model configuration.
    dsl_expression : str | None
        Custom DSL expression for simulation logic.
    random_state : int | None
        Random seed for reproducibility.
    model_output_key : str
        Key to extract from Item.model_outputs. Default: "lm_score".
    fallback_to_random : bool
        Whether to fallback to random if model outputs missing. Default: True.

    Examples
    --------
    >>> # LM score-based with temperature
    >>> config = SimulatedAnnotatorConfig(
    ...     strategy="lm_score",
    ...     noise_model=NoiseModelConfig(noise_type="temperature", temperature=1.5),
    ...     random_state=42
    ... )
    >>>
    >>> # Distance-based with embeddings
    >>> config = SimulatedAnnotatorConfig(
    ...     strategy="distance",
    ...     model_output_key="embedding",
    ...     noise_model=NoiseModelConfig(noise_type="none")
    ... )
    >>>
    >>> # Custom DSL logic
    >>> config = SimulatedAnnotatorConfig(
    ...     strategy="dsl",
    ...     dsl_expression="sample_categorical(softmax(model_scores) / temperature)",
    ...     noise_model=NoiseModelConfig(noise_type="temperature", temperature=1.0)
    ... )
    """

    strategy: Literal["lm_score", "distance", "random", "oracle", "dsl"] = Field(
        default="lm_score",
        description="Base simulation strategy",
    )
    noise_model: NoiseModelConfig = Field(
        default_factory=NoiseModelConfig,
        description="Noise model configuration",
    )
    dsl_expression: str | None = Field(
        default=None,
        description="Custom DSL expression for simulation",
    )
    random_state: int | None = Field(
        default=None,
        description="Random seed for reproducibility",
    )
    model_output_key: str = Field(
        default="lm_score",
        description="Key to extract from model outputs",
    )
    fallback_to_random: bool = Field(
        default=True,
        description="Fallback to random if model outputs missing",
    )


class SimulationRunnerConfig(BaseModel):
    """Configuration for simulation runner.

    Attributes
    ----------
    annotator_configs : list[SimulatedAnnotatorConfig]
        List of annotator configurations (for multi-annotator simulation).
    n_annotators : int
        Number of simulated annotators. Default: 1.
    inter_annotator_correlation : float | None
        Desired correlation between annotators (0.0-1.0). Default: None (independent).
    output_format : Literal["dict", "dataframe", "jsonl"]
        Output format for simulation results. Default: "dict".
    save_path : Path | None
        Path to save simulation results. Default: None.

    Examples
    --------
    >>> # Single annotator
    >>> config = SimulationRunnerConfig(
    ...     annotator_configs=[SimulatedAnnotatorConfig(strategy="lm_score")],
    ...     n_annotators=1
    ... )
    >>>
    >>> # Multiple independent annotators
    >>> config = SimulationRunnerConfig(
    ...     annotator_configs=[
    ...         SimulatedAnnotatorConfig(strategy="lm_score", random_state=1),
    ...         SimulatedAnnotatorConfig(strategy="lm_score", random_state=2),
    ...         SimulatedAnnotatorConfig(strategy="lm_score", random_state=3)
    ...     ],
    ...     n_annotators=3
    ... )
    >>>
    >>> # Correlated annotators
    >>> config = SimulationRunnerConfig(
    ...     annotator_configs=[SimulatedAnnotatorConfig(strategy="lm_score")],
    ...     n_annotators=5,
    ...     inter_annotator_correlation=0.7  # 70% agreement
    ... )
    """

    annotator_configs: list[SimulatedAnnotatorConfig] = Field(
        default_factory=lambda: [SimulatedAnnotatorConfig()],
        description="Annotator configurations",
    )
    n_annotators: int = Field(
        default=1,
        ge=1,
        le=100,
        description="Number of simulated annotators",
    )
    inter_annotator_correlation: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Inter-annotator correlation",
    )
    output_format: Literal["dict", "dataframe", "jsonl"] = Field(
        default="dict",
        description="Output format",
    )
    save_path: Path | None = Field(
        default=None,
        description="Path to save results",
    )
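
Since these are standard pydantic v2 models, the bounds declared above are enforced at construction time and configs serialize losslessly. A minimal sketch, assuming only the three classes in this file:

from pydantic import ValidationError

# nested construction: runner -> annotators -> noise models
runner = SimulationRunnerConfig(
    annotator_configs=[
        SimulatedAnnotatorConfig(strategy="lm_score", random_state=seed)
        for seed in (1, 2, 3)
    ],
    n_annotators=3,
)

# declared bounds reject bad values: temperature must lie in [0.01, 10.0]
try:
    NoiseModelConfig(noise_type="temperature", temperature=0.0)
except ValidationError as exc:
    print(exc.error_count())  # 1

# round-trip through JSON, e.g. to store the config next to results
restored = SimulationRunnerConfig.model_validate_json(runner.model_dump_json())
assert restored == runner
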
bead/config/template.py
ADDED

@@ -0,0 +1,238 @@
"""Template configuration models for the bead package."""

from __future__ import annotations

from pathlib import Path
from typing import Literal

from pydantic import BaseModel, Field, field_validator, model_validator


class SlotStrategyConfig(BaseModel):
    """Configuration for a single slot's filling strategy.

    Parameters
    ----------
    strategy
        Filling strategy for this slot. Must be one of "exhaustive",
        "random", "stratified", or "mlm".
    sample_size
        Sample size for random or stratified strategies. Only used when
        strategy is "random" or "stratified".
    stratify_by
        Feature name to stratify by. Only used when strategy is "stratified".
    beam_size
        Beam size for MLM strategy. Only used when strategy is "mlm".

    Examples
    --------
    >>> config = SlotStrategyConfig(strategy="exhaustive")
    >>> config.strategy
    'exhaustive'
    >>> config_random = SlotStrategyConfig(strategy="random", sample_size=100)
    >>> config_random.sample_size
    100
    >>> config_stratified = SlotStrategyConfig(
    ...     strategy="stratified", sample_size=50, stratify_by="pos"
    ... )
    >>> config_stratified.stratify_by
    'pos'
    >>> config_mlm = SlotStrategyConfig(strategy="mlm", beam_size=10)
    >>> config_mlm.beam_size
    10
    """

    strategy: Literal["exhaustive", "random", "stratified", "mlm"] = Field(
        ..., description="Filling strategy for this slot"
    )
    sample_size: int | None = Field(
        default=None, description="Sample size for random/stratified"
    )
    stratify_by: str | None = Field(default=None, description="Feature to stratify by")
    beam_size: int | None = Field(default=None, description="Beam size for MLM")


class TemplateConfig(BaseModel):
    """Configuration for template filling.

    Parameters
    ----------
    filling_strategy : str
        Strategy name for filling templates
        ("exhaustive", "random", "stratified", "mlm", "mixed").
    batch_size : int
        Batch size for filling operations.
    max_combinations : int | None
        Maximum combinations to generate.
    random_seed : int | None
        Random seed for reproducibility.
    stream_mode : bool
        Use streaming for large templates.
    use_csp_solver : bool
        Use CSP solver for templates with multi-slot constraints.
    mlm_model_name : str | None
        HuggingFace model name for MLM filling.
    mlm_beam_size : int
        Beam search width for MLM strategy.
    mlm_fill_direction : str
        Direction for filling slots in MLM strategy.
    mlm_custom_order : list[int] | None
        Custom slot fill order for MLM strategy.
    mlm_top_k : int
        Number of top candidates per slot in MLM.
    mlm_device : str
        Device for MLM inference.
    mlm_cache_enabled : bool
        Enable content-addressable caching for MLM predictions.
    mlm_cache_dir : Path | None
        Directory for MLM prediction cache.
    slot_strategies : dict[str, SlotStrategyConfig] | None
        Per-slot strategy configuration for mixed filling.
        Maps slot names to SlotStrategyConfig instances.

    Examples
    --------
    >>> config = TemplateConfig()
    >>> config.filling_strategy
    'exhaustive'
    >>> config.batch_size
    1000
    >>> # MLM configuration
    >>> config_mlm = TemplateConfig(
    ...     filling_strategy="mlm", mlm_model_name="bert-base-uncased"
    ... )
    >>> config_mlm.mlm_beam_size
    5
    >>> # Mixed strategy configuration
    >>> config_mixed = TemplateConfig(
    ...     filling_strategy="mixed",
    ...     mlm_model_name="bert-base-uncased",
    ...     slot_strategies={
    ...         "noun": SlotStrategyConfig(strategy="exhaustive"),
    ...         "verb": SlotStrategyConfig(strategy="exhaustive"),
    ...         "adjective": SlotStrategyConfig(strategy="mlm", beam_size=10)
    ...     }
    ... )
    >>> config_mixed.slot_strategies["noun"].strategy
    'exhaustive'
    >>> config_mixed.slot_strategies["adjective"].beam_size
    10
    """

    filling_strategy: Literal["exhaustive", "random", "stratified", "mlm", "mixed"] = (
        Field(default="exhaustive", description="Strategy for filling templates")
    )
    batch_size: int = Field(default=1000, description="Batch size for filling", gt=0)
    max_combinations: int | None = Field(
        default=None, description="Max combinations to generate"
    )
    random_seed: int | None = Field(
        default=None, description="Random seed for reproducibility"
    )
    stream_mode: bool = Field(
        default=False, description="Use streaming for large templates"
    )
    use_csp_solver: bool = Field(
        default=False,
        description="Use CSP solver for templates with multi-slot constraints",
    )

    # MLM-specific settings (model, beam size, fill direction)
    mlm_model_name: str | None = Field(
        default=None, description="HuggingFace model name for MLM filling"
    )
    mlm_beam_size: int = Field(
        default=5, description="Beam search width for MLM strategy", gt=0
    )
    mlm_fill_direction: Literal[
        "left_to_right", "right_to_left", "inside_out", "outside_in", "custom"
    ] = Field(
        default="left_to_right",
        description="Direction for filling slots in MLM strategy",
    )
    mlm_custom_order: list[int] | None = Field(
        default=None, description="Custom slot fill order for MLM strategy"
    )
    mlm_top_k: int = Field(
        default=20, description="Number of top candidates per slot in MLM", gt=0
    )
    mlm_device: str = Field(default="cpu", description="Device for MLM inference")
    mlm_cache_enabled: bool = Field(
        default=True, description="Enable caching for MLM predictions"
    )
    mlm_cache_dir: Path | None = Field(
        default=None, description="Directory for MLM prediction cache"
    )

    # mixed strategy settings
    slot_strategies: dict[str, SlotStrategyConfig] | None = Field(
        default=None,
        description="Per-slot strategy configuration for mixed filling. "
        "Maps slot names to SlotStrategyConfig instances.",
    )

    @field_validator("max_combinations")
    @classmethod
    def validate_max_combinations(cls, v: int | None) -> int | None:
        """Validate max_combinations is positive.

        Parameters
        ----------
        v : int | None
            Max combinations value.

        Returns
        -------
        int | None
            Validated value.

        Raises
        ------
        ValueError
            If value is not positive.
        """
        if v is not None and v <= 0:
            msg = f"max_combinations must be positive, got {v}"
            raise ValueError(msg)
        return v

    @model_validator(mode="after")
    def validate_mlm_config(self) -> TemplateConfig:
        """Validate MLM configuration is consistent.

        Returns
        -------
        TemplateConfig
            Validated config.

        Raises
        ------
        ValueError
            If MLM config is inconsistent.
        """
        if self.filling_strategy == "mlm" and self.mlm_model_name is None:
            msg = "mlm_model_name must be specified when filling_strategy is 'mlm'"
            raise ValueError(msg)

        if self.mlm_fill_direction == "custom" and self.mlm_custom_order is None:
            msg = (
                "mlm_custom_order must be specified when mlm_fill_direction is 'custom'"
            )
            raise ValueError(msg)

        # validate mixed strategy configuration
        if self.filling_strategy == "mixed" and self.slot_strategies is None:
            msg = "slot_strategies must be specified when filling_strategy is 'mixed'"
            raise ValueError(msg)

        if self.slot_strategies is not None:
            for slot_name, slot_config in self.slot_strategies.items():
                # if MLM strategy is used for a slot, check model config is available
                if slot_config.strategy == "mlm" and self.mlm_model_name is None:
                    msg = (
                        f"mlm_model_name must be specified when slot "
                        f"'{slot_name}' uses MLM"
                    )
                    raise ValueError(msg)

        return self
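
A short sketch of how the two validators above fire, assuming only this module's classes (pydantic wraps the ValueError raised inside a validator into a ValidationError):

from pydantic import ValidationError

# filling_strategy="mlm" without a model name is rejected after construction
try:
    TemplateConfig(filling_strategy="mlm")
except ValidationError as exc:
    print(exc)  # mlm_model_name must be specified when filling_strategy is 'mlm'

# "custom" fill direction requires an explicit slot order
try:
    TemplateConfig(mlm_fill_direction="custom")
except ValidationError as exc:
    print(exc)  # mlm_custom_order must be specified when mlm_fill_direction is 'custom'

# a consistent mixed configuration passes: the one slot that uses MLM
# is backed by the shared mlm_model_name
config = TemplateConfig(
    filling_strategy="mixed",
    mlm_model_name="bert-base-uncased",
    slot_strategies={
        "noun": SlotStrategyConfig(strategy="exhaustive"),
        "adjective": SlotStrategyConfig(strategy="mlm", beam_size=10),
    },
)
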
bead/config/validation.py
ADDED

@@ -0,0 +1,267 @@
"""Configuration validation utilities.

This module provides pre-flight validation for configuration objects,
checking for common issues before the configuration is used.
"""

from __future__ import annotations

from bead.config.config import BeadConfig


def check_paths_exist(config: BeadConfig) -> list[str]:
    """Check that all configured paths exist or can be created.

    Parameters
    ----------
    config : BeadConfig
        Configuration to check.

    Returns
    -------
    list[str]
        List of path validation errors.

    Examples
    --------
    >>> from bead.config import get_default_config
    >>> config = get_default_config()
    >>> errors = check_paths_exist(config)
    >>> isinstance(errors, list)
    True
    """
    errors: list[str] = []

    # check main paths if they should exist and are absolute
    if config.paths.data_dir.is_absolute() and not config.paths.data_dir.exists():
        errors.append(f"data_dir does not exist: {config.paths.data_dir}")

    if config.paths.output_dir.is_absolute() and not config.paths.output_dir.exists():
        errors.append(f"output_dir does not exist: {config.paths.output_dir}")

    if config.paths.cache_dir.is_absolute() and not config.paths.cache_dir.exists():
        errors.append(f"cache_dir does not exist: {config.paths.cache_dir}")

    # check resource paths
    if (
        config.resources.lexicon_path is not None
        and config.resources.lexicon_path.is_absolute()
        and not config.resources.lexicon_path.exists()
    ):
        errors.append(f"lexicon_path does not exist: {config.resources.lexicon_path}")

    if (
        config.resources.templates_path is not None
        and config.resources.templates_path.is_absolute()
        and not config.resources.templates_path.exists()
    ):
        errors.append(
            f"templates_path does not exist: {config.resources.templates_path}"
        )

    if (
        config.resources.constraints_path is not None
        and config.resources.constraints_path.is_absolute()
        and not config.resources.constraints_path.exists()
    ):
        errors.append(
            f"constraints_path does not exist: {config.resources.constraints_path}"
        )

    # check training logging dir
    if (
        config.active_learning.trainer.logging_dir.is_absolute()
        and not config.active_learning.trainer.logging_dir.exists()
    ):
        errors.append(
            f"logging_dir does not exist: {config.active_learning.trainer.logging_dir}"
        )

    # check logging file parent directory
    if (
        config.logging.file is not None
        and config.logging.file.is_absolute()
        and not config.logging.file.parent.exists()
    ):
        parent_dir = config.logging.file.parent
        errors.append(f"logging file parent directory does not exist: {parent_dir}")

    return errors


def check_resource_compatibility(config: BeadConfig) -> list[str]:
    """Verify resources are compatible with templates.

    Parameters
    ----------
    config : BeadConfig
        Configuration to check.

    Returns
    -------
    list[str]
        List of resource compatibility errors.
    """
    errors: list[str] = []

    # check that if templates_path is specified, lexicon_path should also be specified
    if (
        config.resources.templates_path is not None
        and config.resources.lexicon_path is None
    ):
        errors.append(
            "templates_path is specified but lexicon_path is not. "
            "Templates require a lexicon."
        )

    return errors


def check_model_configuration(config: BeadConfig) -> list[str]:
    """Verify model settings are valid.

    Parameters
    ----------
    config : BeadConfig
        Configuration to check.

    Returns
    -------
    list[str]
        List of model configuration errors.
    """
    try:
        import torch  # noqa: PLC0415
    except ImportError:
        torch = None  # type: ignore[assignment]

    errors: list[str] = []

    # check CUDA availability if device is set to cuda
    if config.items.model.device == "cuda":
        if torch is None:
            errors.append(
                "Model device is set to 'cuda' but PyTorch is not installed. "
                "Install PyTorch or set device to 'cpu'."
            )
        elif not torch.cuda.is_available():  # type: ignore[no-untyped-call]
            errors.append(
                "Model device is set to 'cuda' but CUDA is not available. "
                "Set device to 'cpu' or install CUDA."
            )

    # check MPS availability if device is set to mps
    if config.items.model.device == "mps":
        if torch is None:
            errors.append(
                "Model device is set to 'mps' but PyTorch is not installed. "
                "Install PyTorch or set device to 'cpu'."
            )
        elif not torch.backends.mps.is_available():  # type: ignore[no-untyped-call]
            errors.append(
                "Model device is set to 'mps' but MPS is not available. "
                "Set device to 'cpu' or use a macOS system with MPS support."
            )

    return errors


def check_training_configuration(config: BeadConfig) -> list[str]:
    """Verify training settings are compatible.

    Parameters
    ----------
    config : BeadConfig
        Configuration to check.

    Returns
    -------
    list[str]
        List of training configuration errors.
    """
    errors: list[str] = []

    # check that batch size is positive
    if config.active_learning.forced_choice_model.batch_size <= 0:
        batch_size = config.active_learning.forced_choice_model.batch_size
        errors.append(f"Training batch size must be positive, got {batch_size}")

    # check that epochs is positive
    if config.active_learning.trainer.epochs <= 0:
        epochs = config.active_learning.trainer.epochs
        errors.append(f"Training epochs must be positive, got {epochs}")

    # check that learning rate is positive
    if config.active_learning.forced_choice_model.learning_rate <= 0:
        lr = config.active_learning.forced_choice_model.learning_rate
        errors.append(f"Training learning rate must be positive, got {lr}")

    return errors


def check_deployment_configuration(config: BeadConfig) -> list[str]:
    """Verify deployment settings are valid.

    Parameters
    ----------
    config : BeadConfig
        Configuration to check.

    Returns
    -------
    list[str]
        List of deployment configuration errors.
    """
    errors: list[str] = []

    # check jsPsych version format if platform is jspsych
    if config.deployment.platform == "jspsych":
        version = config.deployment.jspsych_version
        if version is None:  # type: ignore[reportUnnecessaryComparison]
            errors.append("jsPsych platform requires jspsych_version to be specified")
        elif not isinstance(version, str):  # type: ignore[reportUnnecessaryIsInstance]
            errors.append(
                f"jspsych_version must be a string, got {type(version).__name__}"
            )

    return errors


def validate_config(config: BeadConfig) -> list[str]:
    """Perform pre-flight validation on configuration.

    Checks:
    - All paths exist (if absolute paths are specified)
    - Resource paths exist (if specified)
    - Model configurations are compatible
    - Training configurations are valid
    - No conflicting settings

    Parameters
    ----------
    config : BeadConfig
        Configuration to validate.

    Returns
    -------
    list[str]
        List of validation errors. Empty if valid.

    Examples
    --------
    >>> from bead.config import get_default_config
    >>> config = get_default_config()
    >>> errors = validate_config(config)
    >>> len(errors)
    0
    """
    errors: list[str] = []

    # run all validation checks
    errors.extend(check_paths_exist(config))
    errors.extend(check_resource_compatibility(config))
    errors.extend(check_model_configuration(config))
    errors.extend(check_training_configuration(config))
    errors.extend(check_deployment_configuration(config))

    return errors
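
validate_config collects errors rather than raising, so callers can report every problem at once. A minimal pre-flight sketch; get_default_config is the accessor used in the doctests above, and the exit-on-error policy is illustrative:

import sys

from bead.config import get_default_config
from bead.config.validation import validate_config

config = get_default_config()
errors = validate_config(config)  # aggregates all five check_* functions

if errors:
    # report every problem at once rather than failing on the first
    for error in errors:
        print(f"config error: {error}", file=sys.stderr)
    sys.exit(1)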