bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/cli/active_learning_commands.py
@@ -0,0 +1,779 @@
"""Additional active learning CLI commands.

This module contains the select-items and run commands that were too large
to include in the main active_learning.py file.
"""

from __future__ import annotations

import json
import traceback
from pathlib import Path
from typing import Any, Literal

import click
import numpy as np
import yaml
from pydantic import BaseModel, Field, ValidationError
from rich.console import Console
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
)
from rich.table import Table

from bead.active_learning.loop import ActiveLearningLoop
from bead.active_learning.models.base import ActiveLearningModel
from bead.active_learning.models.binary import BinaryModel
from bead.active_learning.models.categorical import CategoricalModel
from bead.active_learning.models.forced_choice import ForcedChoiceModel
from bead.active_learning.selection import ItemSelector, UncertaintySampler
from bead.cli.utils import print_error, print_info, print_success
from bead.config.active_learning import (
    ActiveLearningLoopConfig,
    BinaryModelConfig,
    CategoricalModelConfig,
    ForcedChoiceModelConfig,
    UncertaintySamplerConfig,
)
from bead.items.item import Item
from bead.items.item_template import ItemTemplate

console = Console()


# Configuration models for the run command
StoppingCriterion = Literal["max_iterations", "convergence", "performance_threshold"]


class RunLoopConfig(BaseModel):
    """Loop configuration for active learning run command."""

    max_iterations: int = Field(default=10, gt=0)
    budget_per_iteration: int = Field(default=100, gt=0)
    stopping_criterion: StoppingCriterion = "max_iterations"
    performance_threshold: float | None = Field(default=None, ge=0, le=1)
    metric_name: str = "accuracy"
    convergence_patience: int = Field(default=3, gt=0)
    convergence_threshold: float = Field(default=0.01, gt=0)


class RunModelConfig(BaseModel):
    """Model configuration for active learning run command."""

    type: Literal["binary", "categorical", "forced_choice"] = "binary"
    model_name: str = "bert-base-uncased"
    max_length: int = Field(default=128, gt=0)
    learning_rate: float = Field(default=2e-5, gt=0)
    batch_size: int = Field(default=16, gt=0)
    num_epochs: int = Field(default=3, gt=0)
    device: Literal["cpu", "cuda", "mps"] = "cpu"


class RunSelectionConfig(BaseModel):
    """Selection configuration for active learning run command."""

    method: Literal["entropy", "margin", "least_confidence"] = "entropy"
    batch_size: int | None = Field(default=None, gt=0)


class RunDataConfig(BaseModel):
    """Data paths configuration for active learning run command."""

    initial_items: Path
    unlabeled_pool: Path
    item_template: Path
    human_ratings: Path | None = None


class ActiveLearningRunConfig(BaseModel):
    """Full configuration for active learning run command."""

    loop: RunLoopConfig = Field(default_factory=RunLoopConfig)
    model: RunModelConfig = Field(default_factory=RunModelConfig)
    selection: RunSelectionConfig = Field(default_factory=RunSelectionConfig)
    data: RunDataConfig


def load_run_config(config_path: Path) -> ActiveLearningRunConfig:
    """Load active learning run configuration from YAML file.

    Parameters
    ----------
    config_path : Path
        Path to YAML configuration file.

    Returns
    -------
    ActiveLearningRunConfig
        Validated configuration.

    Raises
    ------
    FileNotFoundError
        If configuration file doesn't exist.
    ValidationError
        If configuration is invalid.
    """
    with open(config_path, encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)
    return ActiveLearningRunConfig(**config_dict)


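# --- Illustrative sketch (added by the editor, not part of this module) -----
# load_run_config() expects a YAML file whose top-level keys mirror
# ActiveLearningRunConfig: loop, model, selection, and data. Parsed with
# yaml.safe_load, a minimal config would produce a mapping like the one below
# (the file paths are hypothetical); omitted fields fall back to the Field
# defaults defined above.
_example_run_config = {
    "loop": {"max_iterations": 5, "budget_per_iteration": 50},
    "model": {"type": "binary", "model_name": "bert-base-uncased", "device": "cpu"},
    "selection": {"method": "entropy"},
    "data": {
        "initial_items": "data/initial_items.jsonl",
        "unlabeled_pool": "data/unlabeled_pool.jsonl",
        "item_template": "data/item_template.jsonl",
        "human_ratings": "data/human_ratings.jsonl",
    },
}
# ActiveLearningRunConfig(**_example_run_config) validates this mapping.
# -----------------------------------------------------------------------------

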
def create_model_from_config(model_config: RunModelConfig) -> ActiveLearningModel:
    """Create model instance from configuration.

    Parameters
    ----------
    model_config : RunModelConfig
        Model configuration.

    Returns
    -------
    ActiveLearningModel
        Configured model instance.

    Raises
    ------
    ValueError
        If model type is unknown.
    """
    model_type = model_config.type

    if model_type == "binary":
        config = BinaryModelConfig(
            model_name=model_config.model_name,
            max_length=model_config.max_length,
            learning_rate=model_config.learning_rate,
            batch_size=model_config.batch_size,
            num_epochs=model_config.num_epochs,
            device=model_config.device,
        )
        return BinaryModel(config=config)
    elif model_type == "categorical":
        config = CategoricalModelConfig(
            model_name=model_config.model_name,
            max_length=model_config.max_length,
            learning_rate=model_config.learning_rate,
            batch_size=model_config.batch_size,
            num_epochs=model_config.num_epochs,
            device=model_config.device,
        )
        return CategoricalModel(config=config)
    elif model_type == "forced_choice":
        config = ForcedChoiceModelConfig(
            model_name=model_config.model_name,
            max_length=model_config.max_length,
            learning_rate=model_config.learning_rate,
            batch_size=model_config.batch_size,
            num_epochs=model_config.num_epochs,
            device=model_config.device,
        )
        return ForcedChoiceModel(config=config)
    else:
        raise ValueError(f"Unknown model type: {model_type}")


def _load_items(path: Path) -> list[Item]:
    """Load items from JSONL file.

    Parameters
    ----------
    path : Path
        Path to JSONL file.

    Returns
    -------
    list[Item]
        List of loaded items.
    """
    items: list[Item] = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            item_data = json.loads(line)
            items.append(Item(**item_data))
    return items


def _load_item_template(path: Path) -> ItemTemplate:
    """Load item template from JSONL file.

    Parameters
    ----------
    path : Path
        Path to JSONL file containing template.

    Returns
    -------
    ItemTemplate
        Loaded item template.

    Raises
    ------
    ValueError
        If no template found in file.
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            template_data = json.loads(line)
            return ItemTemplate(**template_data)
    raise ValueError(f"No item template found in {path}")


def _load_ratings(path: Path) -> dict[str, Any]:
    """Load human ratings from JSONL file.

    Parameters
    ----------
    path : Path
        Path to JSONL file with ratings.

    Returns
    -------
    dict[str, Any]
        Mapping from item_id to label.
    """
    ratings: dict[str, Any] = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            item_id = str(record["item_id"])
            label = record["label"]
            ratings[item_id] = label
    return ratings


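# --- Illustrative sketch (added by the editor, not part of this module) -----
# _load_ratings() reads one JSON object per line and keeps only the "item_id"
# and "label" fields. A minimal ratings file would therefore contain lines
# like these (identifiers and labels are hypothetical):
_example_rating_lines = [
    '{"item_id": "item-0001", "label": 1}',
    '{"item_id": "item-0002", "label": 0}',
]
# json.loads(_example_rating_lines[0])["label"] -> 1
# -----------------------------------------------------------------------------

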
def _save_iteration_results(
    output_dir: Path,
    iteration: int,
    selected_items: list[Item],
    metrics: dict[str, float] | None,
) -> None:
    """Save results from a single iteration.

    Parameters
    ----------
    output_dir : Path
        Output directory.
    iteration : int
        Iteration number.
    selected_items : list[Item]
        Items selected in this iteration.
    metrics : dict[str, float] | None
        Training metrics (if available).
    """
    iter_dir = output_dir / f"iteration_{iteration}"
    iter_dir.mkdir(parents=True, exist_ok=True)

    # Save selected items
    items_path = iter_dir / "selected_items.jsonl"
    with open(items_path, "w", encoding="utf-8") as f:
        for item in selected_items:
            f.write(item.model_dump_json() + "\n")

    # Save metrics if available
    if metrics:
        metrics_path = iter_dir / "metrics.json"
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)


def _display_run_summary(
    iterations_completed: int,
    total_items_selected: int,
    final_metrics: dict[str, float] | None,
) -> None:
    """Display summary table of active learning run.

    Parameters
    ----------
    iterations_completed : int
        Number of iterations completed.
    total_items_selected : int
        Total number of items selected.
    final_metrics : dict[str, float] | None
        Final model metrics.
    """
    table = Table(title="Active Learning Run Summary")
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="green", justify="right")

    table.add_row("Iterations Completed", str(iterations_completed))
    table.add_row("Total Items Selected", str(total_items_selected))

    if final_metrics:
        for metric_name, value in final_metrics.items():
            if isinstance(value, float):
                table.add_row(f"Final {metric_name}", f"{value:.4f}")

    console.print(table)


def _show_dry_run_plan(
    config: ActiveLearningRunConfig,
    output_dir: Path,
) -> None:
    """Show what would be executed in dry-run mode.

    Parameters
    ----------
    config : ActiveLearningRunConfig
        Run configuration.
    output_dir : Path
        Output directory.
    """
    console.print("\n[yellow]DRY RUN MODE - No commands will be executed[/yellow]\n")

    console.print("[bold]Configuration Summary:[/bold]")
    console.print(f" Model type: {config.model.type}")
    console.print(f" Model name: {config.model.model_name}")
    console.print(f" Max iterations: {config.loop.max_iterations}")
    console.print(f" Budget per iteration: {config.loop.budget_per_iteration}")
    console.print(f" Stopping criterion: {config.loop.stopping_criterion}")
    console.print(f" Selection method: {config.selection.method}")

    console.print("\n[bold]Data Paths:[/bold]")
    console.print(f" Initial items: {config.data.initial_items}")
    console.print(f" Unlabeled pool: {config.data.unlabeled_pool}")
    console.print(f" Item template: {config.data.item_template}")
    ratings_path = config.data.human_ratings or "None (required for simulation)"
    console.print(f" Human ratings: {ratings_path}")

    console.print(f"\n[bold]Output directory:[/bold] {output_dir}")


@click.command()
@click.option(
    "--items",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Path to unlabeled items file (JSONL)",
)
@click.option(
    "--model",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--output",
    "-o",
    type=click.Path(path_type=Path),
    required=True,
    help="Output file for selected items (JSONL)",
)
@click.option(
    "--budget",
    type=int,
    required=True,
    help="Number of items to select",
)
@click.option(
    "--method",
    type=click.Choice(["entropy", "margin", "least_confidence"]),
    default="entropy",
    help="Uncertainty sampling method (default: entropy)",
)
@click.pass_context
def select_items(
    ctx: click.Context,
    items: Path,
    model: Path,
    output: Path,
    budget: int,
    method: str,
) -> None:
    r"""Select items for annotation using active learning.

    Uses uncertainty sampling to select the most informative items from
    an unlabeled pool for human annotation.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    items : Path
        Path to unlabeled items file (JSONL).
    model : Path
        Path to trained model directory.
    output : Path
        Output file for selected items (JSONL).
    budget : int
        Number of items to select.
    method : str
        Uncertainty sampling method.

    Examples
    --------
    $ bead active-learning select-items \\
        --items unlabeled_items.jsonl \\
        --model models/binary_model \\
        --output selected_items.jsonl \\
        --budget 50 \\
        --method entropy
    """
    try:
        console.rule("[bold]Item Selection[/bold]")

        # Load items
        print_info(f"Loading items from {items}")
        unlabeled_items: list[Item] = []
        with open(items, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                item_data = json.loads(line)
                item = Item(**item_data)
                unlabeled_items.append(item)

        if len(unlabeled_items) == 0:
            print_error("No items found in file")
            ctx.exit(1)

        print_success(f"Loaded {len(unlabeled_items)} unlabeled items")

        if budget > len(unlabeled_items):
            n_available = len(unlabeled_items)
            print_error(f"Budget ({budget}) exceeds available items ({n_available})")
            ctx.exit(1)

        # Load model
        print_info(f"Loading model from {model}")
        # Try to determine model type from config
        config_path = model / "config.json"
        if not config_path.exists():
            print_error(f"Model config not found at {config_path}")
            ctx.exit(1)

        with open(config_path, encoding="utf-8") as f:
            config_dict = json.load(f)

        # Determine model type and load
        model_type = config_dict.get("model_type") or config_dict.get("task_type")
        cfg = str(config_dict)
        is_binary = model_type == "binary" or "BinaryModelConfig" in cfg
        is_categorical = model_type == "categorical" or "CategoricalModelConfig" in cfg
        is_forced = model_type == "forced_choice" or "ForcedChoiceModelConfig" in cfg

        if is_binary:
            loaded_model = BinaryModel()
            loaded_model.load(str(model))
        elif is_categorical:
            loaded_model = CategoricalModel()
            loaded_model.load(str(model))
        elif is_forced:
            loaded_model = ForcedChoiceModel()
            loaded_model.load(str(model))
        else:
            # Default to binary
            loaded_model = BinaryModel()
            loaded_model.load(str(model))

        print_success("Model loaded successfully")

        # Create item selector
        sampler_config = UncertaintySamplerConfig(method=method)
        selector: ItemSelector = UncertaintySampler(config=sampler_config)

        # Define predict function
        def predict_fn(model_instance: object, item: Item) -> np.ndarray:
            """Get prediction probabilities for an item."""
            predictions = model_instance.predict_proba([item], participant_ids=None)
            return predictions[0]

        # Select items
        print_info(f"Selecting {budget} items using {method} method...")
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Selecting items...", total=None)
            selected_items = selector.select(
                items=unlabeled_items,
                model=loaded_model,
                predict_fn=predict_fn,
                budget=budget,
            )

        print_success(f"Selected {len(selected_items)} items")

        # Save selected items
        print_info(f"Writing selected items to {output}")
        output.parent.mkdir(parents=True, exist_ok=True)
        with open(output, "w", encoding="utf-8") as f:
            for item in selected_items:
                f.write(item.model_dump_json() + "\n")

        print_success(f"Selected items written to {output}")

        # Display summary
        table = Table(title="Selection Summary")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green", justify="right")

        table.add_row("Total unlabeled items", str(len(unlabeled_items)))
        table.add_row("Budget", str(budget))
        table.add_row("Selected items", str(len(selected_items)))
        table.add_row("Method", method)

        console.print(table)

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON: {e}")
        ctx.exit(1)
    except Exception as e:
        print_error(f"Item selection failed: {e}")
        traceback.print_exc()
        ctx.exit(1)


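# --- Illustrative sketch (added by the editor, not part of this module) -----
# The --method choices above are the standard uncertainty-sampling scores
# computed over the probability vector returned by predict_fn. The real
# scoring lives in bead.active_learning.selection.UncertaintySampler; this is
# only a plausible sketch of what each method measures.
import numpy as np  # already imported at the top of this module

def _uncertainty_score(probs: np.ndarray, method: str = "entropy") -> float:
    probs = np.asarray(probs, dtype=float)
    if method == "entropy":  # higher entropy -> more uncertain
        return float(-np.sum(probs * np.log(probs + 1e-12)))
    if method == "margin":  # smaller gap between top two classes -> more uncertain
        top_two = np.sort(probs)[-2:]
        return float(1.0 - (top_two[1] - top_two[0]))
    if method == "least_confidence":  # lower max probability -> more uncertain
        return float(1.0 - probs.max())
    raise ValueError(f"Unknown method: {method}")

# A maximally ambiguous prediction scores as more uncertain than a confident one:
# _uncertainty_score(np.array([0.5, 0.5])) > _uncertainty_score(np.array([0.9, 0.1]))
# -----------------------------------------------------------------------------

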
@click.command()
@click.option(
    "--config",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Path to active learning configuration file (YAML)",
)
@click.option(
    "--output-dir",
    type=click.Path(path_type=Path),
    required=True,
    help="Output directory for active learning results",
)
@click.option(
    "--mode",
    type=click.Choice(["simulation"]),
    default="simulation",
    help="Execution mode: simulation (with ratings file)",
)
@click.option(
    "--dry-run",
    is_flag=True,
    default=False,
    help="Show what would be done without executing",
)
@click.pass_context
def run(
    ctx: click.Context,
    config: Path,
    output_dir: Path,
    mode: str,
    dry_run: bool,
) -> None:
    r"""Run full active learning loop.

    Orchestrates the complete active learning workflow:
    1. Select informative items using uncertainty sampling
    2. Simulate data collection using provided human ratings
    3. Train model on labeled data
    4. Check convergence
    5. Repeat until convergence or max iterations

    Note: Currently only simulation mode is supported, which requires
    a human_ratings file in the configuration. Automated data collection
    via JATOS/Prolific is not yet implemented.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    config : Path
        Path to active learning configuration file (YAML).
    output_dir : Path
        Output directory for results.
    mode : str
        Execution mode (currently only "simulation").
    dry_run : bool
        If True, show plan without executing.

    Examples
    --------
    $ bead active-learning run \\
        --config configs/active_learning.yaml \\
        --output-dir results/

    $ bead active-learning run \\
        --config configs/active_learning.yaml \\
        --output-dir results/ \\
        --dry-run
    """
    try:
        console.rule("[bold]Active Learning Loop[/bold]")

        # Step 1: Load and validate configuration
        print_info(f"Loading configuration from {config}")
        try:
            run_config = load_run_config(config)
        except ValidationError as e:
            print_error(f"Configuration validation error: {e}")
            ctx.exit(1)
            return

        # Step 2: Validate mode requirements
        if mode == "simulation" and run_config.data.human_ratings is None:
            print_error("Simulation mode requires human_ratings path in config")
            print_info("Add 'human_ratings: path/to/ratings.jsonl' to data section")
            ctx.exit(1)
            return

        # Step 3: Create output directory
        output_dir.mkdir(parents=True, exist_ok=True)

        # Step 4: Handle dry run
        if dry_run:
            _show_dry_run_plan(run_config, output_dir)
            return

        # Step 5: Load data
        print_info(f"Loading initial items from {run_config.data.initial_items}")
        initial_items = _load_items(run_config.data.initial_items)
        print_success(f"Loaded {len(initial_items)} initial items")

        print_info(f"Loading unlabeled pool from {run_config.data.unlabeled_pool}")
        unlabeled_pool = _load_items(run_config.data.unlabeled_pool)
        print_success(f"Loaded {len(unlabeled_pool)} unlabeled items")

        print_info(f"Loading item template from {run_config.data.item_template}")
        item_template = _load_item_template(run_config.data.item_template)
        print_success("Loaded item template")

        human_ratings: dict[str, Any] | None = None
        if run_config.data.human_ratings:
            print_info(f"Loading human ratings from {run_config.data.human_ratings}")
            human_ratings = _load_ratings(run_config.data.human_ratings)
            print_success(f"Loaded {len(human_ratings)} human ratings")

        # Step 6: Create model
        print_info(f"Creating {run_config.model.type} model...")
        model = create_model_from_config(run_config.model)
        print_success("Model created")

        # Step 7: Create item selector
        sampler_config = UncertaintySamplerConfig(
            method=run_config.selection.method,
            batch_size=run_config.selection.batch_size,
        )
        item_selector: ItemSelector = UncertaintySampler(config=sampler_config)

        # Step 8: Create loop config
        loop_config = ActiveLearningLoopConfig(
            max_iterations=run_config.loop.max_iterations,
            budget_per_iteration=run_config.loop.budget_per_iteration,
            stopping_criterion=run_config.loop.stopping_criterion,
            performance_threshold=run_config.loop.performance_threshold,
            metric_name=run_config.loop.metric_name,
            convergence_patience=run_config.loop.convergence_patience,
            convergence_threshold=run_config.loop.convergence_threshold,
        )

        # Step 9: Create and run loop
        print_info("Initializing active learning loop...")
        loop = ActiveLearningLoop(
            item_selector=item_selector,
            config=loop_config,
        )

        # Step 10: Run with progress reporting
        print_info("Starting active learning loop...")
        console.print()

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task(
                "Running active learning...",
                total=run_config.loop.max_iterations,
            )

            try:
                loop.run(
                    initial_items=initial_items,
                    initial_model=model,
                    item_template=item_template,
                    unlabeled_pool=unlabeled_pool,
                    human_ratings=human_ratings,
                )

                # Update progress based on actual iterations completed
                iterations_completed = len(loop.iteration_history)
                progress.update(task, completed=iterations_completed)

            except Exception as e:
                print_error(f"Active learning loop failed: {e}")
                traceback.print_exc()
                ctx.exit(1)
                return

        # Step 11: Save results
        print_info("Saving results...")
        total_items_selected = 0
        final_metrics: dict[str, float] | None = None

        for i, iteration_result in enumerate(loop.iteration_history):
            selected_items = iteration_result.get("selected_items", [])
            total_items_selected += len(selected_items)
            metrics = iteration_result.get("metrics")

            _save_iteration_results(
                output_dir=output_dir,
                iteration=i,
                selected_items=selected_items,
                metrics=metrics,
            )

            if metrics:
                final_metrics = metrics

        # Save run summary
        summary = {
            "iterations_completed": len(loop.iteration_history),
            "total_items_selected": total_items_selected,
            "config": run_config.model_dump(mode="json"),
        }
        summary_path = output_dir / "run_summary.json"
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2, default=str)

        print_success(f"Results saved to {output_dir}")

        # Step 12: Display summary
        console.print()
        _display_run_summary(
            iterations_completed=len(loop.iteration_history),
            total_items_selected=total_items_selected,
            final_metrics=final_metrics,
        )

        print_success("Active learning completed!")

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON: {e}")
        ctx.exit(1)
    except Exception as e:
        print_error(f"Active learning run failed: {e}")
        traceback.print_exc()
        ctx.exit(1)
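

# --- Illustrative sketch (added by the editor, not part of this package) ----
# The run command above writes one directory per iteration plus a top-level
# run_summary.json. Assuming that layout (iteration_<n>/selected_items.jsonl,
# optional iteration_<n>/metrics.json), results can be read back like this;
# "results" is a hypothetical output directory.
import json
from pathlib import Path

def read_run_results(output_dir: Path = Path("results")) -> dict:
    summary = json.loads((output_dir / "run_summary.json").read_text(encoding="utf-8"))
    iterations = []
    iter_dirs = sorted(
        output_dir.glob("iteration_*"),
        key=lambda p: int(p.name.split("_")[-1]),
    )
    for iter_dir in iter_dirs:
        lines = (iter_dir / "selected_items.jsonl").read_text(encoding="utf-8").splitlines()
        items = [json.loads(line) for line in lines if line.strip()]
        metrics_file = iter_dir / "metrics.json"
        metrics = (
            json.loads(metrics_file.read_text(encoding="utf-8"))
            if metrics_file.exists()
            else None
        )
        iterations.append({"selected_items": items, "metrics": metrics})
    return {"summary": summary, "iterations": iterations}
# -----------------------------------------------------------------------------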