bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/loop.py
@@ -0,0 +1,566 @@
+"""Active learning loop orchestration.
+
+This module orchestrates the iterative active learning loop (stages 3-6):
+construct items → deploy experiment → collect data → train model → select
+next items. It manages convergence detection and coordinates all components.
+"""
+
+from __future__ import annotations
+
+from datetime import UTC, datetime
+from pathlib import Path
+from typing import TYPE_CHECKING, TypedDict
+
+import numpy as np
+from sklearn.metrics import accuracy_score
+
+from bead.active_learning.selection import ItemSelector
+from bead.active_learning.trainers.base import ModelMetadata
+from bead.config.active_learning import ActiveLearningLoopConfig  # needed at runtime for the default config in __init__
+from bead.data_collection.jatos import JATOSDataCollector
+from bead.data_collection.merger import DataMerger
+from bead.data_collection.prolific import ProlificDataCollector
+from bead.evaluation.convergence import ConvergenceDetector
+from bead.items.item import Item
+from bead.items.item_template import ItemTemplate
+
+if TYPE_CHECKING:
+    from bead.active_learning.models.base import ActiveLearningModel
+
+
+class IterationResult(TypedDict):
+    """Results from a single active learning iteration.
+
+    Attributes
+    ----------
+    iteration : int
+        Iteration number.
+    selected_items : list[Item]
+        Items selected for annotation in this iteration.
+    model : ActiveLearningModel
+        Updated model after this iteration.
+    metadata : ModelMetadata | None
+        Training metadata if the model was retrained, None otherwise.
+    """
+
+    iteration: int
+    selected_items: list[Item]
+    model: ActiveLearningModel
+    metadata: ModelMetadata | None
+
+
+class ActiveLearningLoop:
+    """Orchestrates the active learning loop (stages 3-6).
+
+    Manages the iterative process of selecting informative items,
+    training models on collected data, and determining when to stop.
+
+    Note: Data collection integration is not yet implemented, so this
+    loop uses placeholder interfaces for deployment and data collection.
+    The focus is on the selection logic and loop orchestration.
+
+    Parameters
+    ----------
+    item_selector : ItemSelector
+        Algorithm for selecting informative items.
+    config : ActiveLearningLoopConfig | None
+        Configuration object. If None, uses the default configuration.
+
+    Attributes
+    ----------
+    item_selector : ItemSelector
+        Item selection algorithm.
+    config : ActiveLearningLoopConfig
+        Loop configuration.
+    iteration_history : list[IterationResult]
+        History of all iterations with structured results.
+
+    Examples
+    --------
+    >>> from bead.active_learning.selection import UncertaintySampler
+    >>> from bead.config.active_learning import ActiveLearningLoopConfig
+    >>> selector = UncertaintySampler()
+    >>> config = ActiveLearningLoopConfig(  # doctest: +SKIP
+    ...     max_iterations=5,
+    ...     budget_per_iteration=100
+    ... )
+    >>> loop = ActiveLearningLoop(  # doctest: +SKIP
+    ...     item_selector=selector,
+    ...     config=config
+    ... )
+    """
+
+    def __init__(
+        self,
+        item_selector: ItemSelector,
+        config: ActiveLearningLoopConfig | None = None,
+    ) -> None:
+        """Initialize the active learning loop.
+
+        Parameters
+        ----------
+        item_selector : ItemSelector
+            Algorithm for selecting items.
+        config : ActiveLearningLoopConfig | None
+            Configuration object. If None, uses the default configuration.
+        """
+        self.item_selector = item_selector
+        self.config = config or ActiveLearningLoopConfig()
+        self.iteration_history: list[IterationResult] = []
+
+        # Initialize data collectors if configured
+        self.jatos_collector: JATOSDataCollector | None = None
+        self.prolific_collector: ProlificDataCollector | None = None
+        self.data_merger: DataMerger | None = None
+
+        if self.config.jatos is not None:
+            self.jatos_collector = JATOSDataCollector(
+                base_url=self.config.jatos.base_url,
+                api_token=self.config.jatos.api_token,
+                study_id=self.config.jatos.study_id,
+            )
+
+        if self.config.prolific is not None:
+            self.prolific_collector = ProlificDataCollector(
+                api_key=self.config.prolific.api_key,
+                study_id=self.config.prolific.study_id,
+            )
+
+        if self.jatos_collector or self.prolific_collector:
+            self.data_merger = DataMerger()
+
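Note on the wiring above: the collectors and the merger are only created when `config.jatos` or `config.prolific` is present. A minimal sketch of enabling the JATOS path, assuming `ActiveLearningLoopConfig` is pydantic-style and coerces a plain dict into its `jatos` sub-config; only the field names read by `__init__` (`base_url`, `api_token`, `study_id`) are confirmed by the code above:

```python
from bead.active_learning.loop import ActiveLearningLoop
from bead.active_learning.selection import UncertaintySampler
from bead.config.active_learning import ActiveLearningLoopConfig

# Assumption: the jatos sub-config can be supplied as a dict and is coerced
# into whatever model class bead.config.active_learning defines for it.
config = ActiveLearningLoopConfig(
    max_iterations=5,
    budget_per_iteration=100,
    jatos={
        "base_url": "https://jatos.example.org",  # placeholder server
        "api_token": "jap_example_token",          # placeholder credential
        "study_id": 123,
    },
)

loop = ActiveLearningLoop(item_selector=UncertaintySampler(), config=config)
assert loop.jatos_collector is not None  # built because config.jatos is set
assert loop.data_merger is not None      # built because a collector exists
```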
+    def run(
+        self,
+        initial_items: list[Item],
+        initial_model: ActiveLearningModel,
+        item_template: ItemTemplate,
+        unlabeled_pool: list[Item],
+        human_ratings: dict[str, str] | None = None,
+        convergence_detector: ConvergenceDetector | None = None,
+    ) -> list[ModelMetadata]:
+        """Run the complete active learning loop.
+
+        Parameters
+        ----------
+        initial_items : list[Item]
+            Initial labeled items for training.
+        initial_model : ActiveLearningModel
+            Model instance to use for active learning.
+        item_template : ItemTemplate
+            Template used to construct all items. Required for validating
+            model compatibility with the task type.
+        unlabeled_pool : list[Item]
+            Pool of unlabeled items to select from.
+        human_ratings : dict[str, str] | None
+            Human ratings mapping item_id to option names.
+        convergence_detector : ConvergenceDetector | None
+            Detector for checking convergence to human-level performance.
+            If provided, convergence is checked after each iteration.
+
+        Returns
+        -------
+        list[ModelMetadata]
+            Metadata for all trained models across iterations.
+
+        Raises
+        ------
+        ValueError
+            If stopping_criterion is invalid or a required threshold is
+            not provided.
+
+        Notes
+        -----
+        Stopping criteria and performance thresholds are configured via
+        the `config` parameter passed to `__init__`.
+
+        Examples
+        --------
+        >>> from uuid import uuid4
+        >>> from bead.items.item import Item
+        >>> from bead.config.active_learning import ActiveLearningLoopConfig
+        >>> selector = UncertaintySampler()  # doctest: +SKIP
+        >>> config = ActiveLearningLoopConfig(max_iterations=3)  # doctest: +SKIP
+        >>> loop = ActiveLearningLoop(  # doctest: +SKIP
+        ...     item_selector=selector,
+        ...     config=config
+        ... )
+        >>> # run() would typically be called here with real data
+        """
+        # Validate inputs based on config
+        stopping_criterion = self.config.stopping_criterion
+        performance_threshold = self.config.performance_threshold
+        metric_name = self.config.metric_name
+
+        if (
+            stopping_criterion == "performance_threshold"
+            and performance_threshold is None
+        ):
+            raise ValueError(
+                "performance_threshold must be provided in config when using "
+                "the performance_threshold stopping criterion"
+            )
+
+        if stopping_criterion == "convergence" and convergence_detector is None:
+            raise ValueError(
+                "convergence_detector must be provided when using "
+                "the convergence stopping criterion"
+            )
+
+        current_model: ActiveLearningModel = initial_model
+
+        # Validate model compatibility with the item task type
+        if item_template.task_type not in current_model.supported_task_types:
+            raise ValueError(
+                f"Model {type(current_model).__name__} does not support "
+                f"task type '{item_template.task_type}'. "
+                f"Supported types: {current_model.supported_task_types}"
+            )
+
+        # Validate all initial items for structural compatibility
+        for item in initial_items:
+            current_model.validate_item_compatibility(item, item_template)
+
+        # Validate all unlabeled items for structural compatibility
+        for item in unlabeled_pool:
+            current_model.validate_item_compatibility(item, item_template)
+
+        model_history: list[ModelMetadata] = []
+        current_unlabeled = unlabeled_pool.copy()
+        labeled_items = initial_items.copy()
+
+        # Check if we have any unlabeled items to start with
+        if not current_unlabeled:
+            return model_history
+
+        # Run iterations
+        for iteration in range(self.config.max_iterations):
+            # Extract labels for the current labeled items
+            if human_ratings is None:
+                # No ratings provided, so the model cannot be trained
+                break
+
+            labels = [
+                human_ratings.get(str(item.id), "option_a") for item in labeled_items
+            ]
+
+            # Train the model
+            train_metrics = current_model.train(items=labeled_items, labels=labels)
+
+            # Evaluate the model on the labeled items
+            predictions = current_model.predict(labeled_items)
+            pred_labels = [p.predicted_class for p in predictions]
+
+            # Compute accuracy
+            accuracy = accuracy_score(labels, pred_labels)
+
+            # Create metadata
+            training_config_dict = {
+                "iteration": iteration,
+                "max_iterations": self.config.max_iterations,
+                "budget_per_iteration": self.config.budget_per_iteration,
+            }
+
+            metadata = ModelMetadata(
+                model_name="ActiveLearningModel",
+                framework="custom",
+                training_config=training_config_dict,
+                training_data_path=Path("active_learning_data"),
+                metrics={"accuracy": accuracy, **train_metrics},
+                training_time=0.0,
+                training_timestamp=datetime.now(UTC).isoformat(),
+            )
+            model_history.append(metadata)
+
+            # Run one iteration for item selection
+            iteration_result = self.run_iteration(
+                iteration=iteration,
+                unlabeled_items=current_unlabeled,
+                current_model=current_model,
+            )
+
+            # Store results
+            self.iteration_history.append(iteration_result)
+
+            # Update state
+            selected_items = iteration_result["selected_items"]
+            current_model = iteration_result["model"]
+
+            # Add selected items to the labeled set
+            labeled_items.extend(selected_items)
+
+            # Remove selected items from the unlabeled pool
+            selected_ids = {item.id for item in selected_items}
+            current_unlabeled = [
+                item for item in current_unlabeled if item.id not in selected_ids
+            ]
+
+            # Check stopping criteria
+            if stopping_criterion == "max_iterations":
+                # Will stop naturally at max_iterations
+                pass
+            elif stopping_criterion == "performance_threshold":
+                if metadata and metric_name in metadata.metrics:
+                    if metadata.metrics[metric_name] >= performance_threshold:  # type: ignore
+                        break
+            elif stopping_criterion == "convergence":
+                if convergence_detector is not None and metadata is not None:
+                    # Compute the human baseline on the first iteration
+                    if iteration == 0 and human_ratings is not None:
+                        convergence_detector.compute_human_baseline(human_ratings)
+
+                    # Check if converged
+                    if metric_name in metadata.metrics:
+                        converged = convergence_detector.check_convergence(
+                            model_accuracy=metadata.metrics[metric_name],
+                            iteration=iteration + 1,
+                        )
+
+                        if converged:
+                            print(f"✓ Converged at iteration {iteration + 1}")
+                            break
+
+            # Check if the unlabeled pool is exhausted
+            if not current_unlabeled:
+                break
+
+        return model_history
+
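For orientation, a hedged end-to-end sketch of driving `run()`. The `run()` signature and the `Item` construction follow the code and doctests above; the no-argument `ForcedChoiceModel()` and the `item_template = ...` placeholder are assumptions standing in for real construction:

```python
from uuid import uuid4

from bead.active_learning.loop import ActiveLearningLoop
from bead.active_learning.models import ForcedChoiceModel
from bead.active_learning.selection import UncertaintySampler
from bead.config.active_learning import ActiveLearningLoopConfig
from bead.items.item import Item

template_id = uuid4()
labeled = [Item(item_template_id=template_id, rendered_elements={}) for _ in range(10)]
pool = [Item(item_template_id=template_id, rendered_elements={}) for _ in range(200)]

# Ratings are keyed by str(item.id), which is how run() looks them up;
# "option_a" mirrors the default label used in the training loop above.
ratings = {str(item.id): "option_a" for item in labeled}

item_template = ...  # an ItemTemplate whose task_type the model supports (construction omitted)

loop = ActiveLearningLoop(
    item_selector=UncertaintySampler(),
    config=ActiveLearningLoopConfig(max_iterations=3, budget_per_iteration=20),
)
history = loop.run(
    initial_items=labeled,
    initial_model=ForcedChoiceModel(),  # assumed default-constructible
    item_template=item_template,
    unlabeled_pool=pool,
    human_ratings=ratings,
)
print([m.metrics["accuracy"] for m in history])
```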
+    def run_iteration(
+        self,
+        iteration: int,
+        unlabeled_items: list[Item],
+        current_model: ActiveLearningModel,
+    ) -> IterationResult:
+        """Run one iteration of the active learning loop.
+
+        Steps:
+        1. Select informative items using uncertainty sampling
+        2. (Placeholder) Deploy experiment for data collection
+        3. (Placeholder) Wait for and collect data
+        4. (Placeholder) Train new model on augmented dataset
+        5. Return results
+
+        Parameters
+        ----------
+        iteration : int
+            Current iteration number.
+        unlabeled_items : list[Item]
+            Unlabeled items available for selection.
+        current_model : ActiveLearningModel
+            Current trained model for making predictions.
+
+        Returns
+        -------
+        IterationResult
+            Structured iteration results containing:
+            - iteration: Iteration number
+            - selected_items: List of selected items
+            - model: Updated model
+            - metadata: Training metadata if available
+
+        Examples
+        --------
+        >>> from uuid import uuid4
+        >>> from bead.items.item import Item
+        >>> from bead.active_learning.selection import UncertaintySampler
+        >>> from bead.config.active_learning import ActiveLearningLoopConfig
+        >>> loop = ActiveLearningLoop(
+        ...     item_selector=UncertaintySampler(),
+        ...     config=ActiveLearningLoopConfig(
+        ...         max_iterations=5,
+        ...         budget_per_iteration=2,
+        ...     ),
+        ... )
+        >>> items = [
+        ...     Item(item_template_id=uuid4(), rendered_elements={})
+        ...     for _ in range(5)
+        ... ]
+        >>> # with a trained ActiveLearningModel instance as `model`:
+        >>> result = loop.run_iteration(0, items, model)  # doctest: +SKIP
+        >>> len(result["selected_items"])  # doctest: +SKIP
+        2
+        >>> result["iteration"]  # doctest: +SKIP
+        0
+        """
+        # Step 1: Select items using active learning
+        budget = min(self.config.budget_per_iteration, len(unlabeled_items))
+
+        def model_predict_fn(model: ActiveLearningModel, item: Item) -> np.ndarray:
+            """Get prediction probabilities for a single item."""
+            proba = model.predict_proba([item])
+            return proba[0]
+
+        selected_items = self.item_selector.select(
+            items=unlabeled_items,
+            model=current_model,
+            predict_fn=model_predict_fn,
+            budget=budget,
+        )
+
+        # Step 2: Deploy experiment (PLACEHOLDER - data collection not yet implemented)
+        # In the future, this would:
+        # - Create experiment lists using ListPartitioner
+        # - Generate the jsPsych experiment using JsPsychExperimentGenerator
+        # - Export to JATOS format
+        # - Return deployment info for manual upload
+
+        # Step 3: Collect data (PLACEHOLDER - data collection not yet implemented)
+        # In the future, this would:
+        # - Wait for participants to complete experiments
+        # - Use JATOSDataCollector to download results
+        # - Use ProlificDataCollector to get participant metadata
+        # - Use DataMerger to merge JATOS and Prolific data
+
+        # Step 4: Train new model (PLACEHOLDER - training data not available)
+        # In the future, this would:
+        # - Merge old training data with newly collected data
+        # - Call trainer.train() with the augmented dataset
+        # - Return the updated model and metadata
+
+        # For now, return placeholder results
+        return IterationResult(
+            iteration=iteration,
+            selected_items=selected_items,
+            model=current_model,  # Unchanged for now
+            metadata=None,  # Would contain training metrics
+        )
+
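The selection plug-point is narrow: `run_iteration` hands the selector a `predict_fn` mapping `(model, item)` to a probability vector, plus a budget. A sketch of a custom least-confidence selector written against that call signature (whether `ItemSelector` is an ABC or a protocol is an assumption; the keyword arguments match the `select()` call above):

```python
from collections.abc import Callable

import numpy as np

from bead.active_learning.models.base import ActiveLearningModel
from bead.active_learning.selection import ItemSelector
from bead.items.item import Item


class LeastConfidenceSelector(ItemSelector):
    """Pick the items whose top predicted class has the lowest probability."""

    def select(
        self,
        items: list[Item],
        model: ActiveLearningModel,
        predict_fn: Callable[[ActiveLearningModel, Item], np.ndarray],
        budget: int,
    ) -> list[Item]:
        # Confidence = probability of the argmax class; lower means the
        # model is less sure, hence the item is more informative.
        confidences = [float(predict_fn(model, item).max()) for item in items]
        ranked = np.argsort(confidences)[:budget]
        return [items[i] for i in ranked]
```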
+    def check_convergence(
+        self,
+        metrics_history: list[dict[str, float]],
+        metric_name: str = "accuracy",
+        patience: int = 3,
+        min_delta: float = 0.01,
+    ) -> bool:
+        """Check if model performance has converged.
+
+        Uses early-stopping logic: if performance has not improved by
+        at least min_delta over the last `patience` iterations, the
+        loop is considered converged.
+
+        For metrics where lower is better (like "loss"), values are
+        negated so the same improvement test applies.
+
+        Parameters
+        ----------
+        metrics_history : list[dict[str, float]]
+            History of metrics from each iteration.
+        metric_name : str
+            Name of the metric to track.
+        patience : int
+            Number of iterations without improvement before stopping.
+        min_delta : float
+            Minimum change to count as improvement.
+
+        Returns
+        -------
+        bool
+            True if converged, False otherwise.
+
+        Examples
+        --------
+        >>> loop = ActiveLearningLoop(  # doctest: +SKIP
+        ...     item_selector=UncertaintySampler()
+        ... )
+        >>> # Improving performance - not converged
+        >>> history = [
+        ...     {"accuracy": 0.7},
+        ...     {"accuracy": 0.75},
+        ...     {"accuracy": 0.8}
+        ... ]
+        >>> loop.check_convergence(history, metric_name="accuracy", patience=2)
+        False
+        >>> # No improvement of at least min_delta for 3 iterations - converged
+        >>> history = [
+        ...     {"accuracy": 0.8},
+        ...     {"accuracy": 0.81},
+        ...     {"accuracy": 0.805},
+        ...     {"accuracy": 0.81}
+        ... ]
+        >>> loop.check_convergence(
+        ...     history, metric_name="accuracy", patience=3, min_delta=0.02
+        ... )
+        True
+        """
+        if len(metrics_history) < patience + 1:
+            return False
+
+        # Look at the value from `patience` iterations ago plus everything since
+        recent_metrics = [m[metric_name] for m in metrics_history[-(patience + 1) :]]
+
+        # Determine if lower is better (like loss) or higher is better (like accuracy)
+        is_lower_better = metric_name.lower() in ["loss", "error", "mse", "rmse", "mae"]
+        if is_lower_better:
+            # Negate so that "improvement" always means "increase"
+            recent_metrics = [-value for value in recent_metrics]
+
+        # Converged when none of the last `patience` values improves on the
+        # baseline (the value from `patience` iterations ago) by >= min_delta
+        baseline = recent_metrics[0]
+        return all(value - baseline < min_delta for value in recent_metrics[1:])
+
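The doctests above only exercise the higher-is-better branch. Loss-like metrics are negated internally, so the call shape is identical; a small sketch, with construction as in the doctests (assuming `UncertaintySampler` imports from `bead.active_learning.selection`):

```python
from bead.active_learning.loop import ActiveLearningLoop
from bead.active_learning.selection import UncertaintySampler

loop = ActiveLearningLoop(item_selector=UncertaintySampler())

# Loss is still falling by more than min_delta: not converged.
falling = [{"loss": 0.50}, {"loss": 0.45}, {"loss": 0.44}, {"loss": 0.435}]
print(loop.check_convergence(falling, metric_name="loss", patience=3, min_delta=0.01))
# -> False

# Loss has been flat for `patience` iterations: converged.
flat = [{"loss": 0.450}, {"loss": 0.449}, {"loss": 0.448}, {"loss": 0.449}]
print(loop.check_convergence(flat, metric_name="loss", patience=3, min_delta=0.01))
# -> True
```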
+    def get_summary(self) -> dict[str, int | dict[str, int]]:
+        """Get summary statistics of the active learning loop.
+
+        Returns
+        -------
+        dict[str, int | dict[str, int]]
+            Summary dictionary with the following keys:
+
+            total_iterations : int
+                Total number of iterations run.
+            total_items_selected : int
+                Total items selected across all iterations.
+            convergence_info : dict[str, int]
+                Configuration parameters (max_iterations, budget_per_iteration).
+
+        Examples
+        --------
+        >>> from bead.active_learning.selection import UncertaintySampler
+        >>> loop = ActiveLearningLoop(item_selector=UncertaintySampler())
+        >>> summary = loop.get_summary()
+        >>> summary["total_iterations"]
+        0
+        >>> summary["total_items_selected"]
+        0
+        """
+        total_items = sum(
+            len(iteration["selected_items"]) for iteration in self.iteration_history
+        )
+
+        return {
+            "total_iterations": len(self.iteration_history),
+            "total_items_selected": total_items,
+            "convergence_info": {
+                "max_iterations": self.config.max_iterations,
+                "budget_per_iteration": self.config.budget_per_iteration,
+            },
+        }
bead/active_learning/models/__init__.py
@@ -0,0 +1,24 @@
+"""Active learning models for different task types."""
+
+from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
+from bead.active_learning.models.binary import BinaryModel
+from bead.active_learning.models.categorical import CategoricalModel
+from bead.active_learning.models.cloze import ClozeModel
+from bead.active_learning.models.forced_choice import ForcedChoiceModel
+from bead.active_learning.models.free_text import FreeTextModel
+from bead.active_learning.models.magnitude import MagnitudeModel
+from bead.active_learning.models.multi_select import MultiSelectModel
+from bead.active_learning.models.ordinal_scale import OrdinalScaleModel
+
+__all__ = [
+    "ActiveLearningModel",
+    "BinaryModel",
+    "CategoricalModel",
+    "ClozeModel",
+    "ForcedChoiceModel",
+    "FreeTextModel",
+    "MagnitudeModel",
+    "ModelPrediction",
+    "MultiSelectModel",
+    "OrdinalScaleModel",
+]
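The `__all__` list above is the public import surface for the task-specific models. A minimal sketch of choosing a model and inspecting the same `supported_task_types` attribute that `run()` validates against (the no-argument constructor is an assumption):

```python
from bead.active_learning.models import ForcedChoiceModel

model = ForcedChoiceModel()  # assumed default-constructible

# run() raises ValueError when the template's task_type is not in this set,
# so it can be checked up front before building the loop.
print(model.supported_task_types)
```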