bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,206 @@
1
+ """Simulation configuration models for the bead package."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Literal
7
+
8
+ from pydantic import BaseModel, Field
9
+
10
+
11
class NoiseModelConfig(BaseModel):
    """Configuration for noise model in simulated judgments.

    Attributes
    ----------
    noise_type : Literal["temperature", "systematic", "random", "none"]
        Type of noise to apply.
    temperature : float
        Temperature for scaling (higher = more random). Default: 1.0.
    bias_strength : float
        Strength of systematic biases (0.0-1.0). Default: 0.0.
    bias_type : str | None
        Type of bias ("length", "frequency", "position"). Default: None.
    random_noise_stddev : float
        Standard deviation for random noise. Default: 0.0.

    Examples
    --------
    >>> # Temperature-scaled decisions (more random)
    >>> config = NoiseModelConfig(noise_type="temperature", temperature=2.0)
    >>>
    >>> # Systematic length bias (prefer shorter)
    >>> config = NoiseModelConfig(
    ...     noise_type="systematic",
    ...     bias_strength=0.3,
    ...     bias_type="length"
    ... )
    >>>
    >>> # Random noise injection
    >>> config = NoiseModelConfig(
    ...     noise_type="random",
    ...     random_noise_stddev=0.1
    ... )
    """

    # which noise transform to apply; "none" disables noise entirely
    noise_type: Literal["temperature", "systematic", "random", "none"] = Field(
        default="temperature",
        description="Type of noise model",
    )
    # only meaningful for noise_type="temperature"; bounds keep the
    # scaling away from degenerate values (near-zero or extreme)
    temperature: float = Field(
        default=1.0,
        ge=0.01,
        le=10.0,
        description="Temperature for scaling decisions",
    )
    # only meaningful for noise_type="systematic"
    bias_strength: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Strength of systematic biases",
    )
    # free-form bias label (docstring suggests "length"/"frequency"/"position");
    # not validated here — presumably interpreted by the systematic noise model
    bias_type: str | None = Field(
        default=None,
        description="Type of systematic bias",
    )
    # only meaningful for noise_type="random"
    random_noise_stddev: float = Field(
        default=0.0,
        ge=0.0,
        description="Standard deviation for random noise",
    )
71
+
72
+
73
class SimulatedAnnotatorConfig(BaseModel):
    """Configuration for simulated annotator.

    Attributes
    ----------
    strategy : Literal["lm_score", "distance", "random", "oracle", "dsl"]
        Base strategy for generating judgments.
    noise_model : NoiseModelConfig
        Noise model configuration.
    dsl_expression : str | None
        Custom DSL expression for simulation logic.
    random_state : int | None
        Random seed for reproducibility.
    model_output_key : str
        Key to extract from Item.model_outputs. Default: "lm_score".
    fallback_to_random : bool
        Whether to fallback to random if model outputs missing. Default: True.

    Examples
    --------
    >>> # LM score-based with temperature
    >>> config = SimulatedAnnotatorConfig(
    ...     strategy="lm_score",
    ...     noise_model=NoiseModelConfig(noise_type="temperature", temperature=1.5),
    ...     random_state=42
    ... )
    >>>
    >>> # Distance-based with embeddings
    >>> config = SimulatedAnnotatorConfig(
    ...     strategy="distance",
    ...     model_output_key="embedding",
    ...     noise_model=NoiseModelConfig(noise_type="none")
    ... )
    >>>
    >>> # Custom DSL logic
    >>> config = SimulatedAnnotatorConfig(
    ...     strategy="dsl",
    ...     dsl_expression="sample_categorical(softmax(model_scores) / temperature)",
    ...     noise_model=NoiseModelConfig(noise_type="temperature", temperature=1.0)
    ... )
    """

    # base judgment-generation strategy; "dsl" expects dsl_expression
    strategy: Literal["lm_score", "distance", "random", "oracle", "dsl"] = Field(
        default="lm_score",
        description="Base simulation strategy",
    )
    # each annotator owns its noise model; default_factory avoids a
    # shared mutable default across instances
    noise_model: NoiseModelConfig = Field(
        default_factory=NoiseModelConfig,
        description="Noise model configuration",
    )
    # only meaningful for strategy="dsl"; not cross-validated here
    dsl_expression: str | None = Field(
        default=None,
        description="Custom DSL expression for simulation",
    )
    # None means non-reproducible (unseeded) simulation
    random_state: int | None = Field(
        default=None,
        description="Random seed for reproducibility",
    )
    # which entry of an item's model outputs the strategy consumes
    model_output_key: str = Field(
        default="lm_score",
        description="Key to extract from model outputs",
    )
    # graceful degradation: use random judgments when the configured
    # model output is absent instead of failing
    fallback_to_random: bool = Field(
        default=True,
        description="Fallback to random if model outputs missing",
    )
139
+
140
+
141
class SimulationRunnerConfig(BaseModel):
    """Configuration for simulation runner.

    Attributes
    ----------
    annotator_configs : list[SimulatedAnnotatorConfig]
        List of annotator configurations (for multi-annotator simulation).
    n_annotators : int
        Number of simulated annotators. Default: 1.
    inter_annotator_correlation : float | None
        Desired correlation between annotators (0.0-1.0). Default: None (independent).
    output_format : Literal["dict", "dataframe", "jsonl"]
        Output format for simulation results. Default: "dict".
    save_path : Path | None
        Path to save simulation results. Default: None.

    Examples
    --------
    >>> # Single annotator
    >>> config = SimulationRunnerConfig(
    ...     annotator_configs=[SimulatedAnnotatorConfig(strategy="lm_score")],
    ...     n_annotators=1
    ... )
    >>>
    >>> # Multiple independent annotators
    >>> config = SimulationRunnerConfig(
    ...     annotator_configs=[
    ...         SimulatedAnnotatorConfig(strategy="lm_score", random_state=1),
    ...         SimulatedAnnotatorConfig(strategy="lm_score", random_state=2),
    ...         SimulatedAnnotatorConfig(strategy="lm_score", random_state=3)
    ...     ],
    ...     n_annotators=3
    ... )
    >>>
    >>> # Correlated annotators
    >>> config = SimulationRunnerConfig(
    ...     annotator_configs=[SimulatedAnnotatorConfig(strategy="lm_score")],
    ...     n_annotators=5,
    ...     inter_annotator_correlation=0.7  # 70% agreement
    ... )
    """

    # lambda factory so each instance gets a fresh one-element list with
    # a default annotator (mutable default would be shared otherwise)
    annotator_configs: list[SimulatedAnnotatorConfig] = Field(
        default_factory=lambda: [SimulatedAnnotatorConfig()],
        description="Annotator configurations",
    )
    # NOTE(review): n_annotators is not cross-validated against
    # len(annotator_configs) here — presumably reconciled by the runner
    n_annotators: int = Field(
        default=1,
        ge=1,
        le=100,
        description="Number of simulated annotators",
    )
    # None means annotators are independent
    inter_annotator_correlation: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Inter-annotator correlation",
    )
    output_format: Literal["dict", "dataframe", "jsonl"] = Field(
        default="dict",
        description="Output format",
    )
    # None means results are not persisted to disk
    save_path: Path | None = Field(
        default=None,
        description="Path to save results",
    )
@@ -0,0 +1,238 @@
1
+ """Template configuration models for the bead package."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Literal
7
+
8
+ from pydantic import BaseModel, Field, field_validator, model_validator
9
+
10
+
11
class SlotStrategyConfig(BaseModel):
    """Configuration for a single slot's filling strategy.

    Parameters
    ----------
    strategy
        Filling strategy for this slot. Must be one of "exhaustive",
        "random", "stratified", or "mlm".
    sample_size
        Sample size for random or stratified strategies. Only used when
        strategy is "random" or "stratified".
    stratify_by
        Feature name to stratify by. Only used when strategy is "stratified".
    beam_size
        Beam size for MLM strategy. Only used when strategy is "mlm".

    Examples
    --------
    >>> config = SlotStrategyConfig(strategy="exhaustive")
    >>> config.strategy
    'exhaustive'
    >>> config_random = SlotStrategyConfig(strategy="random", sample_size=100)
    >>> config_random.sample_size
    100
    >>> config_stratified = SlotStrategyConfig(
    ...     strategy="stratified", sample_size=50, stratify_by="pos"
    ... )
    >>> config_stratified.stratify_by
    'pos'
    >>> config_mlm = SlotStrategyConfig(strategy="mlm", beam_size=10)
    >>> config_mlm.beam_size
    10
    """

    # strategy is the only required field; the remaining fields are
    # strategy-specific extras and are not cross-validated here
    # (NOTE(review): e.g. strategy="stratified" without stratify_by is accepted)
    strategy: Literal["exhaustive", "random", "stratified", "mlm"] = Field(
        ..., description="Filling strategy for this slot"
    )
    # used by "random" and "stratified" only
    sample_size: int | None = Field(
        default=None, description="Sample size for random/stratified"
    )
    # used by "stratified" only
    stratify_by: str | None = Field(default=None, description="Feature to stratify by")
    # used by "mlm" only; overrides TemplateConfig.mlm_beam_size per slot
    beam_size: int | None = Field(default=None, description="Beam size for MLM")
53
+
54
+
55
class TemplateConfig(BaseModel):
    """Configuration for template filling.

    Parameters
    ----------
    filling_strategy : str
        Strategy name for filling templates
        ("exhaustive", "random", "stratified", "mlm", "mixed").
    batch_size : int
        Batch size for filling operations.
    max_combinations : int | None
        Maximum combinations to generate.
    random_seed : int | None
        Random seed for reproducibility.
    stream_mode : bool
        Use streaming for large templates.
    use_csp_solver : bool
        Use CSP solver for templates with multi-slot constraints.
    mlm_model_name : str | None
        HuggingFace model name for MLM filling.
    mlm_beam_size : int
        Beam search width for MLM strategy.
    mlm_fill_direction : str
        Direction for filling slots in MLM strategy.
    mlm_custom_order : list[int] | None
        Custom slot fill order for MLM strategy.
    mlm_top_k : int
        Number of top candidates per slot in MLM.
    mlm_device : str
        Device for MLM inference.
    mlm_cache_enabled : bool
        Enable content-addressable caching for MLM predictions.
    mlm_cache_dir : Path | None
        Directory for MLM prediction cache.
    slot_strategies : dict[str, SlotStrategyConfig] | None
        Per-slot strategy configuration for mixed filling.
        Maps slot names to SlotStrategyConfig instances.

    Examples
    --------
    >>> config = TemplateConfig()
    >>> config.filling_strategy
    'exhaustive'
    >>> config.batch_size
    1000
    >>> # MLM configuration
    >>> config_mlm = TemplateConfig(
    ...     filling_strategy="mlm", mlm_model_name="bert-base-uncased"
    ... )
    >>> config_mlm.mlm_beam_size
    5
    >>> # Mixed strategy configuration
    >>> config_mixed = TemplateConfig(
    ...     filling_strategy="mixed",
    ...     mlm_model_name="bert-base-uncased",
    ...     slot_strategies={
    ...         "noun": SlotStrategyConfig(strategy="exhaustive"),
    ...         "verb": SlotStrategyConfig(strategy="exhaustive"),
    ...         "adjective": SlotStrategyConfig(strategy="mlm", beam_size=10)
    ...     }
    ... )
    >>> config_mixed.slot_strategies["noun"].strategy
    'exhaustive'
    >>> config_mixed.slot_strategies["adjective"].beam_size
    10
    """

    # "mixed" delegates per-slot decisions to slot_strategies (validated below)
    filling_strategy: Literal["exhaustive", "random", "stratified", "mlm", "mixed"] = (
        Field(default="exhaustive", description="Strategy for filling templates")
    )
    batch_size: int = Field(default=1000, description="Batch size for filling", gt=0)
    # None means unbounded; positivity enforced by validate_max_combinations
    max_combinations: int | None = Field(
        default=None, description="Max combinations to generate"
    )
    random_seed: int | None = Field(
        default=None, description="Random seed for reproducibility"
    )
    stream_mode: bool = Field(
        default=False, description="Use streaming for large templates"
    )
    use_csp_solver: bool = Field(
        default=False,
        description="Use CSP solver for templates with multi-slot constraints",
    )

    # MLM-specific settings (model, beam size, fill direction)
    mlm_model_name: str | None = Field(
        default=None, description="HuggingFace model name for MLM filling"
    )
    mlm_beam_size: int = Field(
        default=5, description="Beam search width for MLM strategy", gt=0
    )
    # "custom" requires mlm_custom_order (validated in validate_mlm_config)
    mlm_fill_direction: Literal[
        "left_to_right", "right_to_left", "inside_out", "outside_in", "custom"
    ] = Field(
        default="left_to_right",
        description="Direction for filling slots in MLM strategy",
    )
    mlm_custom_order: list[int] | None = Field(
        default=None, description="Custom slot fill order for MLM strategy"
    )
    mlm_top_k: int = Field(
        default=20, description="Number of top candidates per slot in MLM", gt=0
    )
    mlm_device: str = Field(default="cpu", description="Device for MLM inference")
    mlm_cache_enabled: bool = Field(
        default=True, description="Enable caching for MLM predictions"
    )
    # None presumably falls back to a default cache location — set by the filler
    mlm_cache_dir: Path | None = Field(
        default=None, description="Directory for MLM prediction cache"
    )

    # mixed strategy settings
    slot_strategies: dict[str, SlotStrategyConfig] | None = Field(
        default=None,
        description="Per-slot strategy configuration for mixed filling. "
        "Maps slot names to SlotStrategyConfig instances.",
    )

    @field_validator("max_combinations")
    @classmethod
    def validate_max_combinations(cls, v: int | None) -> int | None:
        """Validate max_combinations is positive.

        Parameters
        ----------
        v : int | None
            Max combinations value.

        Returns
        -------
        int | None
            Validated value.

        Raises
        ------
        ValueError
            If value is not positive.
        """
        # None (unbounded) is allowed; only a concrete non-positive value fails
        if v is not None and v <= 0:
            msg = f"max_combinations must be positive, got {v}"
            raise ValueError(msg)
        return v

    @model_validator(mode="after")
    def validate_mlm_config(self) -> TemplateConfig:
        """Validate MLM configuration is consistent.

        Runs after field validation, so cross-field dependencies
        (strategy vs. model name, fill direction vs. custom order,
        mixed strategy vs. slot_strategies) can be checked here.

        Returns
        -------
        TemplateConfig
            Validated config.

        Raises
        ------
        ValueError
            If MLM config is inconsistent.
        """
        # a whole-template MLM strategy cannot run without a model
        if self.filling_strategy == "mlm" and self.mlm_model_name is None:
            msg = "mlm_model_name must be specified when filling_strategy is 'mlm'"
            raise ValueError(msg)

        if self.mlm_fill_direction == "custom" and self.mlm_custom_order is None:
            msg = (
                "mlm_custom_order must be specified when mlm_fill_direction is 'custom'"
            )
            raise ValueError(msg)

        # validate mixed strategy configuration
        if self.filling_strategy == "mixed" and self.slot_strategies is None:
            msg = "slot_strategies must be specified when filling_strategy is 'mixed'"
            raise ValueError(msg)

        if self.slot_strategies is not None:
            for slot_name, slot_config in self.slot_strategies.items():
                # if MLM strategy is used for a slot, check model config is available
                if slot_config.strategy == "mlm" and self.mlm_model_name is None:
                    msg = (
                        f"mlm_model_name must be specified when slot "
                        f"'{slot_name}' uses MLM"
                    )
                    raise ValueError(msg)

        return self
@@ -0,0 +1,267 @@
1
+ """Configuration validation utilities.
2
+
3
+ This module provides pre-flight validation for configuration objects,
4
+ checking for common issues before the configuration is used.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from bead.config.config import BeadConfig
10
+
11
+
12
def check_paths_exist(config: BeadConfig) -> list[str]:
    """Check that all configured paths exist or can be created.

    Only absolute paths are checked: a relative path is presumed to be
    resolved later against a working directory, so its absence here is
    not reported as an error. Optional paths set to None are skipped.

    Parameters
    ----------
    config : BeadConfig
        Configuration to check.

    Returns
    -------
    list[str]
        List of path validation errors. Empty if all checks pass.

    Examples
    --------
    >>> from bead.config import get_default_config
    >>> config = get_default_config()
    >>> errors = check_paths_exist(config)
    >>> isinstance(errors, list)
    True
    """
    errors: list[str] = []

    def _require(path: Path | None, label: str) -> None:
        # record an error for a configured absolute path missing on disk
        if path is not None and path.is_absolute() and not path.exists():
            errors.append(f"{label} does not exist: {path}")

    # main working directories
    _require(config.paths.data_dir, "data_dir")
    _require(config.paths.output_dir, "output_dir")
    _require(config.paths.cache_dir, "cache_dir")

    # optional resource paths
    _require(config.resources.lexicon_path, "lexicon_path")
    _require(config.resources.templates_path, "templates_path")
    _require(config.resources.constraints_path, "constraints_path")

    # training logging directory
    _require(config.active_learning.trainer.logging_dir, "logging_dir")

    # the log file itself may not exist yet, but its parent directory must
    log_file = config.logging.file
    if (
        log_file is not None
        and log_file.is_absolute()
        and not log_file.parent.exists()
    ):
        errors.append(
            f"logging file parent directory does not exist: {log_file.parent}"
        )

    return errors
90
+
91
+
92
def check_resource_compatibility(config: BeadConfig) -> list[str]:
    """Verify resources are compatible with templates.

    Templates reference lexical items, so configuring templates_path
    without a lexicon_path is flagged as an error.

    Parameters
    ----------
    config : BeadConfig
        Configuration to check.

    Returns
    -------
    list[str]
        List of resource compatibility errors. Empty if compatible.
    """
    resources = config.resources
    templates_without_lexicon = (
        resources.templates_path is not None and resources.lexicon_path is None
    )

    if templates_without_lexicon:
        return [
            "templates_path is specified but lexicon_path is not. "
            "Templates require a lexicon."
        ]
    return []
118
+
119
+
120
def check_model_configuration(config: BeadConfig) -> list[str]:
    """Verify model settings are valid.

    Confirms the configured inference device is usable in the current
    environment: PyTorch must be importable for any accelerator, and the
    requested backend (CUDA or MPS) must report itself available.

    Parameters
    ----------
    config : BeadConfig
        Configuration to check.

    Returns
    -------
    list[str]
        List of model configuration errors. Empty if valid.
    """
    try:
        import torch  # noqa: PLC0415
    except ImportError:
        torch = None  # type: ignore[assignment]

    errors: list[str] = []
    device = config.items.model.device

    if device in ("cuda", "mps"):
        if torch is None:
            # an accelerator was requested but PyTorch itself is absent
            errors.append(
                f"Model device is set to '{device}' but PyTorch is not installed. "
                "Install PyTorch or set device to 'cpu'."
            )
        elif device == "cuda" and not torch.cuda.is_available():  # type: ignore[no-untyped-call]
            errors.append(
                "Model device is set to 'cuda' but CUDA is not available. "
                "Set device to 'cpu' or install CUDA."
            )
        elif device == "mps" and not torch.backends.mps.is_available():  # type: ignore[no-untyped-call]
            errors.append(
                "Model device is set to 'mps' but MPS is not available. "
                "Set device to 'cpu' or use a macOS system with MPS support."
            )

    return errors
167
+
168
+
169
def check_training_configuration(config: BeadConfig) -> list[str]:
    """Verify training settings are compatible.

    Sanity-checks the numeric hyperparameters of the active-learning
    forced-choice model and trainer; all must be strictly positive.

    Parameters
    ----------
    config : BeadConfig
        Configuration to check.

    Returns
    -------
    list[str]
        List of training configuration errors. Empty if valid.
    """
    errors: list[str] = []
    fc_model = config.active_learning.forced_choice_model
    trainer = config.active_learning.trainer

    if fc_model.batch_size <= 0:
        errors.append(
            f"Training batch size must be positive, got {fc_model.batch_size}"
        )

    if trainer.epochs <= 0:
        errors.append(f"Training epochs must be positive, got {trainer.epochs}")

    if fc_model.learning_rate <= 0:
        errors.append(
            f"Training learning rate must be positive, got {fc_model.learning_rate}"
        )

    return errors
200
+
201
+
202
def check_deployment_configuration(config: BeadConfig) -> list[str]:
    """Verify deployment settings are valid.

    The jsPsych platform requires a jspsych_version string; other
    platforms carry no extra checks here.

    Parameters
    ----------
    config : BeadConfig
        Configuration to check.

    Returns
    -------
    list[str]
        List of deployment configuration errors. Empty if valid.
    """
    # only the jsPsych platform has version requirements to verify
    if config.deployment.platform != "jspsych":
        return []

    version = config.deployment.jspsych_version
    if version is None:  # type: ignore[reportUnnecessaryComparison]
        return ["jsPsych platform requires jspsych_version to be specified"]
    if not isinstance(version, str):  # type: ignore[reportUnnecessaryIsInstance]
        return [f"jspsych_version must be a string, got {type(version).__name__}"]
    return []
228
+
229
+
230
def validate_config(config: BeadConfig) -> list[str]:
    """Perform pre-flight validation on configuration.

    Runs every validation check in a fixed order and concatenates their
    findings:

    - All paths exist (if absolute paths are specified)
    - Resource paths exist (if specified)
    - Model configurations are compatible
    - Training configurations are valid
    - No conflicting settings

    Parameters
    ----------
    config : BeadConfig
        Configuration to validate.

    Returns
    -------
    list[str]
        List of validation errors. Empty if valid.

    Examples
    --------
    >>> from bead.config import get_default_config
    >>> config = get_default_config()
    >>> errors = validate_config(config)
    >>> len(errors)
    0
    """
    checks = (
        check_paths_exist,
        check_resource_compatibility,
        check_model_configuration,
        check_training_configuration,
        check_deployment_configuration,
    )
    # flatten the per-check error lists, preserving check order
    return [error for check in checks for error in check(config)]