bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,391 @@
1
+ """Sampling strategies for active learning.
2
+
3
+ This module implements various uncertainty quantification methods for
4
+ active learning item selection, including entropy, margin, and least
5
+ confidence sampling, plus a random baseline.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from abc import ABC, abstractmethod
11
+ from typing import Literal
12
+
13
+ import numpy as np
14
+
15
+
16
class SamplingStrategy(ABC):
    """Abstract interface for active-learning sampling strategies.

    Concrete strategies implement :meth:`compute_scores`, which turns
    prediction probabilities into per-item informativeness scores; the
    shared :meth:`select_top_k` then picks the highest-scoring items.

    Examples
    --------
    >>> import numpy as np
    >>> class MyStrategy(SamplingStrategy):
    ...     def compute_scores(self, probabilities):
    ...         return np.max(probabilities, axis=1)
    >>> strategy = MyStrategy()
    >>> probs = np.array([[0.7, 0.2, 0.1], [0.4, 0.4, 0.2]])
    >>> scores = strategy.compute_scores(probs)
    >>> indices = strategy.select_top_k(scores, k=1)
    >>> len(indices)
    1
    """

    @abstractmethod
    def compute_scores(self, probabilities: np.ndarray) -> np.ndarray:
        """Compute uncertainty scores from prediction probabilities.

        Parameters
        ----------
        probabilities : np.ndarray
            Prediction probabilities with shape (n_samples, n_classes).
            Each row should sum to 1.0.

        Returns
        -------
        np.ndarray
            Uncertainty scores with shape (n_samples,). Higher scores
            indicate more informative/uncertain items.

        Examples
        --------
        >>> strategy = UncertaintySampling()  # doctest: +SKIP
        >>> probs = np.array([[0.5, 0.5], [0.9, 0.1]])  # doctest: +SKIP
        >>> scores = strategy.compute_scores(probs)  # doctest: +SKIP
        >>> scores[0] > scores[1]  # doctest: +SKIP
        True
        """

    def select_top_k(self, scores: np.ndarray, k: int) -> np.ndarray:
        """Return the indices of the ``k`` highest-scoring items.

        Parameters
        ----------
        scores : np.ndarray
            Uncertainty scores with shape (n_samples,).
        k : int
            Number of items to select.

        Returns
        -------
        np.ndarray
            Indices of the top ``k`` items, best first, with shape (k,).
            When ``k >= len(scores)`` every index is returned (in natural
            order); when ``k <= 0`` an empty array is returned.

        Examples
        --------
        >>> strategy = UncertaintySampling()
        >>> scores = np.array([0.5, 0.9, 0.3, 0.7])
        >>> list(strategy.select_top_k(scores, k=2))
        [1, 3]
        """
        n = len(scores)
        if k <= 0:
            return np.array([], dtype=int)
        if k >= n:
            return np.arange(n)

        # argsort is ascending; the last k entries are the winners,
        # flipped so the best-scoring index comes first.
        ascending = np.argsort(scores)
        return ascending[n - k:][::-1]
96
+
97
+
98
class UncertaintySampling(SamplingStrategy):
    """Entropy-based uncertainty sampling.

    Ranks items by the Shannon entropy of the model's predictive
    distribution; items whose entropy is highest are maximally uncertain.

    Mathematical definition:
        H(p) = -∑(p_i * log(p_i))

    where p is the probability distribution over classes.

    Examples
    --------
    >>> import numpy as np
    >>> strategy = UncertaintySampling()
    >>> # Uniform distribution (high entropy)
    >>> probs = np.array([[0.33, 0.33, 0.34]])
    >>> score = strategy.compute_scores(probs)
    >>> score[0] > 1.0  # High uncertainty
    True
    >>> # Confident prediction (low entropy)
    >>> probs = np.array([[0.9, 0.05, 0.05]])
    >>> score = strategy.compute_scores(probs)
    >>> score[0] < 0.5  # Low uncertainty
    True
    """

    def compute_scores(self, probabilities: np.ndarray) -> np.ndarray:
        """Compute the per-row entropy of the predictions.

        Parameters
        ----------
        probabilities : np.ndarray
            Prediction probabilities with shape (n_samples, n_classes).

        Returns
        -------
        np.ndarray
            Entropy scores with shape (n_samples,); higher entropy means
            more uncertainty.

        Examples
        --------
        >>> strategy = UncertaintySampling()
        >>> probs = np.array([[0.5, 0.5], [0.9, 0.1]])
        >>> scores = strategy.compute_scores(probs)
        >>> scores[0] > scores[1]  # Uniform is more uncertain
        True
        """
        # clamp probabilities away from zero so log() stays finite
        clipped = np.clip(probabilities, 1e-10, 1.0)

        # Shannon entropy per row: H(p) = -∑ p_i log(p_i)
        return -(clipped * np.log(clipped)).sum(axis=1)
155
+
156
+
157
class MarginSampling(SamplingStrategy):
    """Margin-based uncertainty sampling.

    Ranks items by how close the two most probable classes are: a small
    gap between them means the model cannot decide, so the item is
    informative.

    Mathematical definition:
        margin(p) = 1 - (p₁ - p₂)

    where p₁ and p₂ are the highest and second-highest probabilities.

    Examples
    --------
    >>> import numpy as np
    >>> strategy = MarginSampling()
    >>> # Small margin (uncertain)
    >>> probs = np.array([[0.51, 0.49, 0.0]])
    >>> score = strategy.compute_scores(probs)
    >>> score[0] > 0.95  # High uncertainty
    True
    >>> # Large margin (confident)
    >>> probs = np.array([[0.9, 0.05, 0.05]])
    >>> score = strategy.compute_scores(probs)
    >>> score[0] < 0.2  # Low uncertainty
    True
    """

    def compute_scores(self, probabilities: np.ndarray) -> np.ndarray:
        """Compute margin-based scores for each prediction.

        Parameters
        ----------
        probabilities : np.ndarray
            Prediction probabilities with shape (n_samples, n_classes).

        Returns
        -------
        np.ndarray
            Margin scores with shape (n_samples,); higher scores mean a
            smaller top-two gap (more uncertainty).

        Examples
        --------
        >>> strategy = MarginSampling()
        >>> probs = np.array([[0.6, 0.3, 0.1], [0.8, 0.15, 0.05]])
        >>> scores = strategy.compute_scores(probs)
        >>> scores[0] > scores[1]  # First has smaller margin
        True
        """
        # a full ascending sort is fine here: n_classes is small, and the
        # last two columns are exactly the top-two probabilities
        ordered = np.sort(probabilities, axis=1)
        best = ordered[:, -1]
        runner_up = ordered[:, -2]

        # invert the gap so that a narrow margin yields a high score
        return 1.0 - (best - runner_up)
217
+
218
+
219
class LeastConfidenceSampling(SamplingStrategy):
    """Least-confidence sampling.

    Ranks items by one minus the probability of the model's top choice:
    the lower the winning probability, the more informative the item.

    Mathematical definition:
        lc(p) = 1 - max(p)

    where p is the probability distribution over classes.

    Examples
    --------
    >>> import numpy as np
    >>> strategy = LeastConfidenceSampling()
    >>> # Low confidence
    >>> probs = np.array([[0.4, 0.3, 0.3]])
    >>> score = strategy.compute_scores(probs)
    >>> score[0] == 0.6  # 1 - 0.4
    True
    >>> # High confidence
    >>> probs = np.array([[0.95, 0.03, 0.02]])
    >>> score = strategy.compute_scores(probs)
    >>> score[0] == 0.05  # 1 - 0.95
    True
    """

    def compute_scores(self, probabilities: np.ndarray) -> np.ndarray:
        """Compute least-confidence scores for each prediction.

        Parameters
        ----------
        probabilities : np.ndarray
            Prediction probabilities with shape (n_samples, n_classes).

        Returns
        -------
        np.ndarray
            Scores with shape (n_samples,); higher scores mean a lower
            winning probability (more uncertainty).

        Examples
        --------
        >>> strategy = LeastConfidenceSampling()
        >>> probs = np.array([[0.5, 0.5], [0.9, 0.1]])
        >>> scores = strategy.compute_scores(probs)
        >>> scores[0] > scores[1]  # First is less confident
        True
        """
        # complement of the per-row winning probability
        return 1.0 - probabilities.max(axis=1)
275
+
276
+
277
class RandomSampling(SamplingStrategy):
    """Random-sampling baseline.

    Assigns every item a uniformly random score, ignoring the model's
    predictions entirely. Serves as the control condition against which
    uncertainty-based strategies are compared. A seed makes the draw
    reproducible.

    Parameters
    ----------
    seed : int | None
        Random seed for reproducibility. If None, uses non-deterministic seed.

    Attributes
    ----------
    rng : np.random.Generator
        Random number generator.

    Examples
    --------
    >>> import numpy as np
    >>> strategy = RandomSampling(seed=42)
    >>> probs = np.array([[0.9, 0.1], [0.5, 0.5]])
    >>> scores = strategy.compute_scores(probs)
    >>> len(scores) == 2
    True
    >>> # Scores are random, not based on probabilities
    >>> strategy2 = RandomSampling(seed=42)
    >>> scores2 = strategy2.compute_scores(probs)
    >>> np.allclose(scores, scores2)  # Same seed gives same results
    True
    """

    def __init__(self, seed: int | None = None) -> None:
        self.rng = np.random.default_rng(seed)

    def compute_scores(self, probabilities: np.ndarray) -> np.ndarray:
        """Draw one uniform random score per item.

        Parameters
        ----------
        probabilities : np.ndarray
            Prediction probabilities with shape (n_samples, n_classes).
            Only the row count is used; values are ignored by design.

        Returns
        -------
        np.ndarray
            Random scores in [0, 1) with shape (n_samples,).

        Examples
        --------
        >>> strategy = RandomSampling(seed=123)
        >>> probs = np.array([[0.9, 0.1], [0.1, 0.9]])
        >>> scores = strategy.compute_scores(probs)
        >>> len(scores) == 2
        True
        >>> 0.0 <= scores[0] <= 1.0
        True
        """
        row_count = probabilities.shape[0]
        return self.rng.random(row_count)
338
+
339
+
340
# Type alias for strategy methods
StrategyMethod = Literal["entropy", "margin", "least_confidence", "random"]


def create_strategy(
    method: StrategyMethod, seed: int | None = None
) -> SamplingStrategy:
    """Create a sampling strategy instance.

    Parameters
    ----------
    method : StrategyMethod
        Strategy method name ("entropy", "margin", "least_confidence", "random").
    seed : int | None
        Random seed for random strategy. Ignored for other strategies.

    Returns
    -------
    SamplingStrategy
        Instantiated sampling strategy.

    Raises
    ------
    ValueError
        If method is not recognized.

    Examples
    --------
    >>> strategy = create_strategy("entropy")
    >>> isinstance(strategy, UncertaintySampling)
    True
    >>> strategy = create_strategy("margin")
    >>> isinstance(strategy, MarginSampling)
    True
    >>> strategy = create_strategy("least_confidence")
    >>> isinstance(strategy, LeastConfidenceSampling)
    True
    >>> strategy = create_strategy("random", seed=42)
    >>> isinstance(strategy, RandomSampling)
    True
    """
    # simple guard chain: only "random" takes the seed
    if method == "entropy":
        return UncertaintySampling()
    if method == "margin":
        return MarginSampling()
    if method == "least_confidence":
        return LeastConfidenceSampling()
    if method == "random":
        return RandomSampling(seed=seed)
    raise ValueError(f"Unknown sampling method: {method}")
@@ -0,0 +1,26 @@
1
+ """Training framework adapters.
2
+
3
+ Provides trainer implementations for HuggingFace Transformers and PyTorch
4
+ Lightning. All trainers implement the BaseTrainer interface.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from bead.active_learning.trainers.base import BaseTrainer, ModelMetadata
10
+ from bead.active_learning.trainers.huggingface import HuggingFaceTrainer
11
+ from bead.active_learning.trainers.lightning import PyTorchLightningTrainer
12
+ from bead.active_learning.trainers.registry import (
13
+ get_trainer,
14
+ list_trainers,
15
+ register_trainer,
16
+ )
17
+
18
+ __all__ = [
19
+ "BaseTrainer",
20
+ "HuggingFaceTrainer",
21
+ "ModelMetadata",
22
+ "PyTorchLightningTrainer",
23
+ "get_trainer",
24
+ "list_trainers",
25
+ "register_trainer",
26
+ ]
@@ -0,0 +1,210 @@
1
+ """Base trainer interface for model training.
2
+
3
+ This module provides the abstract base class for all trainers and the
4
+ ModelMetadata model for tracking training results.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from abc import ABC, abstractmethod
10
+ from pathlib import Path
11
+ from typing import TYPE_CHECKING
12
+
13
+ from bead.data.base import BeadBaseModel
14
+
15
+ if TYPE_CHECKING:
16
+ from datasets import Dataset
17
+ from transformers import PreTrainedModel
18
+
19
+
20
class ModelMetadata(BeadBaseModel):
    """Training metadata.

    Record of a single training run: what was trained, with which
    framework and configuration, on which data, and how it scored.

    Parameters
    ----------
    model_name : str
        Model identifier.
    framework : str
        Training framework ("huggingface" or "pytorch_lightning").
    training_config : dict[str, str | int | float | bool | Path | None]
        Training configuration used.
    training_data_path : Path
        Path to training data.
    eval_data_path : Path | None
        Path to eval data if used.
    metrics : dict[str, float]
        Final evaluation metrics.
    best_checkpoint : Path | None
        Path to best checkpoint.
    training_time : float
        Total training time in seconds.
    training_timestamp : str
        ISO 8601 timestamp when training completed.

    Attributes
    ----------
    model_name : str
        Model identifier.
    framework : str
        Training framework ("huggingface" or "pytorch_lightning").
    training_config : dict[str, str | int | float | bool | Path | None]
        Training configuration used.
    training_data_path : Path
        Path to training data.
    eval_data_path : Path | None
        Path to eval data if used.
    metrics : dict[str, float]
        Final evaluation metrics.
    best_checkpoint : Path | None
        Path to best checkpoint.
    training_time : float
        Total training time in seconds.
    training_timestamp : str
        ISO 8601 timestamp when training completed.

    Examples
    --------
    >>> from pathlib import Path
    >>> metadata = ModelMetadata(
    ...     model_name="bert-base-uncased",
    ...     framework="huggingface",
    ...     training_config={"epochs": 3},
    ...     training_data_path=Path("train.json"),
    ...     metrics={"accuracy": 0.95},
    ...     training_time=120.5,
    ...     training_timestamp="2025-01-17T00:00:00+00:00"
    ... )
    >>> metadata.framework
    'huggingface'
    >>> metadata.metrics["accuracy"]
    0.95
    """

    # model identifier (e.g. a HuggingFace hub name)
    model_name: str
    # "huggingface" or "pytorch_lightning" per the class docstring
    framework: str
    # flat framework-specific config; values limited to scalar-ish types
    training_config: dict[str, str | int | float | bool | Path | None]
    training_data_path: Path
    # optional: only set when an eval split was used
    eval_data_path: Path | None = None
    # final evaluation metrics, keyed by metric name
    metrics: dict[str, float]
    # optional: best checkpoint produced during training, if any
    best_checkpoint: Path | None = None
    # wall-clock training duration in seconds
    training_time: float
    # ISO 8601 completion timestamp (stored as a plain string)
    training_timestamp: str
92
+
93
+
94
class BaseTrainer(ABC):
    """Base trainer interface.

    Every trainer implements :meth:`train`, :meth:`save_model`, and
    :meth:`load_model`, giving callers one consistent surface regardless
    of the underlying training framework.

    Parameters
    ----------
    config : dict[str, str | int | float | bool | Path | None] | BeadBaseModel
        Training configuration (framework-specific).

    Attributes
    ----------
    config : dict[str, str | int | float | bool | Path | None] | BeadBaseModel
        Training configuration.

    Examples
    --------
    >>> class MyTrainer(BaseTrainer):
    ...     def train(self, train_data, eval_data=None):
    ...         return ModelMetadata(
    ...             model_name="test",
    ...             framework="custom",
    ...             training_config={},
    ...             training_data_path=Path("train.json"),
    ...             metrics={},
    ...             training_time=0.0,
    ...             training_timestamp="2025-01-17T00:00:00+00:00"
    ...         )
    ...     def save_model(self, output_dir, metadata):
    ...         pass
    ...     def load_model(self, model_dir):
    ...         return None
    >>> trainer = MyTrainer(config={})
    >>> trainer.config
    {}
    """

    def __init__(
        self, config: dict[str, str | int | float | bool | Path | None] | BeadBaseModel
    ) -> None:
        # stored as-is; subclasses interpret it
        self.config = config

    @abstractmethod
    def train(
        self,
        train_data: Dataset
        | dict[str, str | int | float | bool | None]
        | list[dict[str, str | int | float | bool | None]],
        eval_data: Dataset
        | dict[str, str | int | float | bool | None]
        | list[dict[str, str | int | float | bool | None]]
        | None = None,
    ) -> ModelMetadata:
        """Train model and return metadata.

        Parameters
        ----------
        train_data : Dataset | dict | list
            Training dataset (framework-specific format).
        eval_data : Dataset | dict | list | None
            Evaluation dataset (framework-specific format).

        Returns
        -------
        ModelMetadata
            Metadata about the training run.

        Examples
        --------
        >>> trainer = MyTrainer(config={})  # doctest: +SKIP
        >>> metadata = trainer.train(train_dataset)  # doctest: +SKIP
        >>> metadata.framework  # doctest: +SKIP
        'custom'
        """

    @abstractmethod
    def save_model(self, output_dir: Path, metadata: ModelMetadata) -> None:
        """Save model and metadata to directory.

        Parameters
        ----------
        output_dir : Path
            Directory to save model and metadata.
        metadata : ModelMetadata
            Training metadata to save.

        Examples
        --------
        >>> trainer = MyTrainer(config={})  # doctest: +SKIP
        >>> trainer.save_model(Path("output"), metadata)  # doctest: +SKIP
        """

    @abstractmethod
    def load_model(
        self, model_dir: Path
    ) -> PreTrainedModel | dict[str, str | int | float | bool | None] | BeadBaseModel:
        """Load model from directory.

        Parameters
        ----------
        model_dir : Path
            Directory containing saved model.

        Returns
        -------
        PreTrainedModel | dict[str, str | int | float | bool | None] | BeadBaseModel
            Loaded model (framework-specific type).

        Examples
        --------
        >>> trainer = MyTrainer(config={})  # doctest: +SKIP
        >>> model = trainer.load_model(Path("saved_model"))  # doctest: +SKIP
        """