bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,513 @@
1
"""Active learning commands for bead CLI.

This module provides commands for active learning workflows including item
selection and convergence monitoring.
"""

from __future__ import annotations

import json
import re
from pathlib import Path

import click
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

from bead.cli.active_learning_commands import run, select_items
from bead.cli.utils import print_error, print_info, print_success
from bead.evaluation.convergence import ConvergenceDetector
from bead.evaluation.interannotator import InterAnnotatorMetrics

# Shared Rich console used by every command in this module for styled output.
console = Console()
24
+
25
+
26
@click.group()
def active_learning() -> None:
    r"""Active learning commands.

    Commands for convergence detection and active learning workflows.

    \b
    AVAILABLE COMMANDS:
      check-convergence     Check if model converged to human agreement
      monitor-convergence   Monitor convergence over multiple iterations

    \b
    Examples:
      # Check convergence
      $ bead active-learning check-convergence \\
          --predictions predictions.jsonl \\
          --human-labels labels.jsonl \\
          --metric krippendorff_alpha \\
          --threshold 0.85
    """
    # Intentionally empty: this is a pure command group. Subcommands are
    # attached with add_command() at the bottom of this module (including
    # `run` and `select_items` imported from active_learning_commands).
    # NOTE: the docstring doubles as the CLI `--help` text, so its wording
    # is user-facing.
46
+
47
+
48
@click.command()
@click.option(
    "--predictions",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Path to model predictions file (JSONL with 'prediction' field)",
)
@click.option(
    "--human-labels",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Path to human labels file (JSONL with 'label' field per rater)",
)
@click.option(
    "--metric",
    type=click.Choice(
        ["krippendorff_alpha", "fleiss_kappa", "cohens_kappa", "percentage_agreement"],
        case_sensitive=False,
    ),
    default="krippendorff_alpha",
    help="Agreement metric to use (default: krippendorff_alpha)",
)
@click.option(
    "--threshold",
    type=float,
    default=0.80,
    help="Convergence threshold (default: 0.80)",
)
@click.option(
    "--min-iterations",
    type=int,
    default=1,
    help="Minimum iterations before checking convergence (default: 1)",
)
@click.pass_context
def check_convergence(
    ctx: click.Context,
    predictions: Path,
    human_labels: Path,
    metric: str,
    threshold: float,
    min_iterations: int,
) -> None:
    r"""Check if model has converged to human agreement level.

    Compares model predictions with human labels using inter-annotator
    agreement metrics (via ConvergenceDetector from bead.evaluation) to
    determine convergence. Exits 0 when converged, 1 otherwise.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    predictions : Path
        Path to model predictions file.
    human_labels : Path
        Path to human labels file.
    metric : str
        Agreement metric name.
    threshold : float
        Convergence threshold.
    min_iterations : int
        Minimum iterations before allowing convergence.

    Examples
    --------
    $ bead active-learning check-convergence \\
        --predictions predictions.jsonl \\
        --human-labels labels.jsonl \\
        --metric krippendorff_alpha \\
        --threshold 0.85

    $ bead active-learning check-convergence \\
        --predictions predictions.jsonl \\
        --human-labels labels.jsonl \\
        --metric fleiss_kappa \\
        --threshold 0.75
    """
    try:
        console.rule("[bold]Convergence Check[/bold]")

        # Load model predictions (one JSON object per non-blank line).
        print_info(f"Loading predictions from {predictions}")
        with open(predictions, encoding="utf-8") as f:
            pred_records = [json.loads(line) for line in f if line.strip()]

        model_predictions = [r["prediction"] for r in pred_records]
        print_success(f"Loaded {len(model_predictions)} predictions")

        # Load human labels and group them per rater. File order is preserved,
        # so index i of every rater's list is assumed to refer to the same item
        # (TODO confirm against the label-export format).
        print_info(f"Loading human labels from {human_labels}")
        with open(human_labels, encoding="utf-8") as f:
            label_records = [json.loads(line) for line in f if line.strip()]

        rater_labels: dict[str, list[int | str | float]] = {}
        for record in label_records:
            rater_id = str(record.get("rater_id", "rater_1"))
            rater_labels.setdefault(rater_id, []).append(record["label"])

        n_raters = len(rater_labels)
        print_success(f"Loaded labels from {n_raters} raters")

        # Create convergence detector.
        print_info(f"Computing {metric}...")
        detector = ConvergenceDetector(
            human_agreement_metric=metric,
            convergence_threshold=threshold,
            min_iterations=min_iterations,
            statistical_test=True,
        )

        # Human-to-human agreement is the baseline the model must approach.
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Computing human agreement baseline...", total=None)
            human_baseline = detector.compute_human_baseline(rater_labels)

        print_success(f"Human baseline: {human_baseline:.4f}")

        # Treat the model as one more "rater" so multi-rater metrics can
        # include it directly.
        all_raters = {**rater_labels, "model": model_predictions}

        if metric == "krippendorff_alpha":
            model_agreement = InterAnnotatorMetrics.krippendorff_alpha(
                all_raters, metric="nominal"
            )
        else:
            # For other metrics, compare the model against the per-item human
            # majority vote. Ties resolve arbitrarily (max over a set).
            n_items = len(model_predictions)
            human_votes = []
            for i in range(n_items):
                votes_for_item = [rater_labels[r][i] for r in rater_labels]
                majority = max(set(votes_for_item), key=votes_for_item.count)
                human_votes.append(majority)

            # Fraction of items where model matches the human majority.
            model_human_pairs = zip(model_predictions, human_votes, strict=True)
            agreements = sum(p == h for p, h in model_human_pairs)
            model_agreement = agreements / len(model_predictions)

        print_success(f"Model agreement: {model_agreement:.4f}")

        # Single-shot check, so the current min_iterations is passed as the
        # iteration index — presumably so the detector's min-iteration gate is
        # satisfied; verify against ConvergenceDetector.check_convergence.
        converged = detector.check_convergence(
            model_accuracy=model_agreement, iteration=min_iterations
        )

        # Render a summary table.
        table = Table(title="Convergence Results")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green", justify="right")

        table.add_row("Agreement Metric", metric)
        table.add_row("Human Baseline", f"{human_baseline:.4f}")
        table.add_row("Model Agreement", f"{model_agreement:.4f}")
        table.add_row("Threshold", f"{threshold:.4f}")
        table.add_row("Converged", "✓ Yes" if converged else "✗ No")

        if converged:
            table.add_row(
                "Status", "[green]Model has converged to human agreement[/green]"
            )
        else:
            gap = threshold - model_agreement
            table.add_row(
                "Status", f"[yellow]Need {gap:.4f} more to reach threshold[/yellow]"
            )

        console.print(table)

        # Exit code doubles as a machine-readable convergence flag (0 = yes).
        if converged:
            print_success("Convergence achieved!")
            ctx.exit(0)
        else:
            print_info("Not yet converged. Continue training.")
            ctx.exit(1)

    except click.exceptions.Exit:
        # BUGFIX: ctx.exit() signals via click.exceptions.Exit, which
        # subclasses RuntimeError. Without this re-raise, the broad
        # `except Exception` below caught it, so a successful exit(0) was
        # reported as "Convergence check failed: 0" and turned into exit 1.
        raise
    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except KeyError as e:
        print_error(f"Missing required field in data: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON: {e}")
        ctx.exit(1)
    except Exception as e:
        print_error(f"Convergence check failed: {e}")
        ctx.exit(1)
248
+
249
+
250
@click.command()
@click.option(
    "--checkpoint-dir",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Directory containing model checkpoints",
)
@click.option(
    "--human-labels",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Path to human labels file (JSONL with 'label' field per rater)",
)
@click.option(
    "--metric",
    type=click.Choice(
        ["krippendorff_alpha", "fleiss_kappa", "cohens_kappa", "percentage_agreement"],
        case_sensitive=False,
    ),
    default="krippendorff_alpha",
    help="Agreement metric to use (default: krippendorff_alpha)",
)
@click.option(
    "--threshold",
    type=float,
    default=0.80,
    help="Convergence threshold (default: 0.80)",
)
@click.option(
    "--min-iterations",
    type=int,
    default=1,
    help="Minimum iterations before checking convergence (default: 1)",
)
@click.option(
    "--output",
    "-o",
    type=click.Path(path_type=Path),
    default=None,
    help="Output file for convergence report (default: stdout)",
)
@click.pass_context
def monitor_convergence(
    ctx: click.Context,
    checkpoint_dir: Path,
    human_labels: Path,
    metric: str,
    threshold: float,
    min_iterations: int,
    output: Path | None,
) -> None:
    r"""Monitor convergence over multiple iterations.

    Loads model checkpoints from a directory and checks convergence
    against human labels for each iteration. Produces a convergence
    report showing progress over time. Exits 0 if any iteration
    converged, 1 otherwise.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    checkpoint_dir : Path
        Directory containing model checkpoints.
    human_labels : Path
        Path to human labels file.
    metric : str
        Agreement metric name.
    threshold : float
        Convergence threshold.
    min_iterations : int
        Minimum iterations before allowing convergence.
    output : Path | None
        Output file path (None for stdout).

    Examples
    --------
    $ bead active-learning monitor-convergence \\
        --checkpoint-dir models/checkpoints \\
        --human-labels labels.jsonl \\
        --metric krippendorff_alpha \\
        --threshold 0.85

    $ bead active-learning monitor-convergence \\
        --checkpoint-dir models/checkpoints \\
        --human-labels labels.jsonl \\
        --output convergence_report.json
    """
    try:
        console.rule("[bold]Convergence Monitoring[/bold]")

        # Load human labels and group them per rater, preserving file order so
        # that index i of every rater's list refers to the same item.
        print_info(f"Loading human labels from {human_labels}")
        with open(human_labels, encoding="utf-8") as f:
            label_records = [json.loads(line) for line in f if line.strip()]

        rater_labels: dict[str, list[int | str | float]] = {}
        for record in label_records:
            rater_id = str(record.get("rater_id", "rater_1"))
            rater_labels.setdefault(rater_id, []).append(record["label"])

        n_raters = len(rater_labels)
        print_success(f"Loaded labels from {n_raters} raters")

        # Create convergence detector.
        detector = ConvergenceDetector(
            human_agreement_metric=metric,
            convergence_threshold=threshold,
            min_iterations=min_iterations,
            statistical_test=True,
        )

        # Human-to-human agreement is the baseline the model must approach;
        # it is computed once and reused for every checkpoint.
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Computing human agreement baseline...", total=None)
            human_baseline = detector.compute_human_baseline(rater_labels)

        print_success(f"Human baseline: {human_baseline:.4f}")

        # Find per-iteration prediction dumps anywhere under checkpoint_dir.
        checkpoint_files = sorted(checkpoint_dir.glob("**/predictions*.jsonl"))
        if not checkpoint_files:
            print_error(f"No prediction files found in {checkpoint_dir}")
            ctx.exit(1)

        print_info(f"Found {len(checkpoint_files)} checkpoint(s)")

        # Process each checkpoint; files without a parseable iteration number
        # are skipped.
        convergence_history: list[dict[str, str | int | float | bool]] = []
        for checkpoint_file in checkpoint_files:
            iteration_num = _extract_iteration_number(checkpoint_file)
            if iteration_num is None:
                continue

            print_info(f"Processing iteration {iteration_num}...")

            # Load this checkpoint's predictions.
            with open(checkpoint_file, encoding="utf-8") as f:
                pred_records = [json.loads(line) for line in f if line.strip()]

            model_predictions = [r["prediction"] for r in pred_records]

            # Compute model agreement (model treated as one more "rater").
            all_raters = {**rater_labels, "model": model_predictions}
            if metric == "krippendorff_alpha":
                model_agreement = InterAnnotatorMetrics.krippendorff_alpha(
                    all_raters, metric="nominal"
                )
            else:
                # For other metrics, compare to the per-item human majority
                # vote. Ties resolve arbitrarily (max over a set).
                n_items = len(model_predictions)
                human_votes = []
                for i in range(n_items):
                    votes_for_item = [rater_labels[r][i] for r in rater_labels]
                    majority = max(set(votes_for_item), key=votes_for_item.count)
                    human_votes.append(majority)

                agreements = sum(
                    p == h for p, h in zip(model_predictions, human_votes, strict=True)
                )
                model_agreement = agreements / len(model_predictions)

            # Check convergence for this iteration.
            converged = detector.check_convergence(
                model_accuracy=model_agreement, iteration=iteration_num
            )

            convergence_history.append(
                {
                    "iteration": iteration_num,
                    "model_agreement": model_agreement,
                    "human_baseline": human_baseline,
                    "converged": converged,
                    "gap": human_baseline - model_agreement,
                }
            )

        # Render the per-iteration history.
        table = Table(title="Convergence History")
        table.add_column("Iteration", style="cyan")
        table.add_column("Model Agreement", style="green", justify="right")
        table.add_column("Human Baseline", style="blue", justify="right")
        table.add_column("Gap", style="yellow", justify="right")
        table.add_column("Status", style="magenta")

        for record in convergence_history:
            status = "✓ Converged" if record["converged"] else "✗ Not converged"
            table.add_row(
                str(record["iteration"]),
                f"{record['model_agreement']:.4f}",
                f"{record['human_baseline']:.4f}",
                f"{record['gap']:.4f}",
                status,
            )

        console.print(table)

        # Write JSON report if requested.
        if output:
            with open(output, "w", encoding="utf-8") as f:
                json.dump(convergence_history, f, indent=2)
            print_success(f"Report written to {output}")

        # Exit 0 if any iteration converged, 1 otherwise.
        any_converged = any(r["converged"] for r in convergence_history)
        if any_converged:
            print_success("Convergence achieved in at least one iteration!")
            ctx.exit(0)
        else:
            print_info("No convergence detected yet. Continue training.")
            ctx.exit(1)

    except click.exceptions.Exit:
        # BUGFIX: ctx.exit() signals via click.exceptions.Exit, which
        # subclasses RuntimeError. Without this re-raise, the broad
        # `except Exception` below caught it, so a successful exit(0) was
        # reported as "Convergence monitoring failed: 0" and turned into
        # exit 1 (the empty-checkpoint exit(1) also produced a bogus
        # failure message).
        raise
    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except KeyError as e:
        print_error(f"Missing required field in data: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON: {e}")
        ctx.exit(1)
    except Exception as e:
        print_error(f"Convergence monitoring failed: {e}")
        ctx.exit(1)
481
+
482
+
483
+ def _extract_iteration_number(path: Path) -> int | None:
484
+ """Extract iteration number from checkpoint file path.
485
+
486
+ Parameters
487
+ ----------
488
+ path : Path
489
+ Checkpoint file path.
490
+
491
+ Returns
492
+ -------
493
+ int | None
494
+ Iteration number if found, None otherwise.
495
+ """
496
+ # Try to find iteration number in filename
497
+ match = re.search(r"iteration[_-]?(\d+)", path.stem, re.IGNORECASE)
498
+ if match:
499
+ return int(match.group(1))
500
+
501
+ # Try to find number at end of filename
502
+ match = re.search(r"(\d+)", path.stem)
503
+ if match:
504
+ return int(match.group(1))
505
+
506
+ return None
507
+
508
+
509
# Register subcommands on the `active-learning` group. check_convergence and
# monitor_convergence are defined above; run and select_items are imported
# from bead.cli.active_learning_commands.
active_learning.add_command(check_convergence)
active_learning.add_command(monitor_convergence)
active_learning.add_command(run)
active_learning.add_command(select_items)