bead-0.1.0-py3-none-any.whl
This diff shows the content of a publicly available package version as released to a supported public registry, and is provided for informational purposes only.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
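The hunk below corresponds to bead/cli/active_learning.py, the only file in this release with exactly 513 added lines.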
@@ -0,0 +1,513 @@

"""Active learning commands for bead CLI.

This module provides commands for active learning workflows including item
selection and convergence monitoring.
"""

from __future__ import annotations

import json
import re
from pathlib import Path

import click
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

from bead.cli.active_learning_commands import run, select_items
from bead.cli.utils import print_error, print_info, print_success
from bead.evaluation.convergence import ConvergenceDetector
from bead.evaluation.interannotator import InterAnnotatorMetrics

console = Console()


@click.group()
def active_learning() -> None:
    r"""Active learning commands.

    Commands for convergence detection and active learning workflows.

    \b
    AVAILABLE COMMANDS:
        check-convergence      Check if model converged to human agreement
        monitor-convergence    Monitor convergence over multiple iterations

    \b
    Examples:
        # Check convergence
        $ bead active-learning check-convergence \\
            --predictions predictions.jsonl \\
            --human-labels labels.jsonl \\
            --metric krippendorff_alpha \\
            --threshold 0.85
    """


@click.command()
@click.option(
    "--predictions",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Path to model predictions file (JSONL with 'prediction' field)",
)
@click.option(
    "--human-labels",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Path to human labels file (JSONL with 'label' field per rater)",
)
@click.option(
    "--metric",
    type=click.Choice(
        ["krippendorff_alpha", "fleiss_kappa", "cohens_kappa", "percentage_agreement"],
        case_sensitive=False,
    ),
    default="krippendorff_alpha",
    help="Agreement metric to use (default: krippendorff_alpha)",
)
@click.option(
    "--threshold",
    type=float,
    default=0.80,
    help="Convergence threshold (default: 0.80)",
)
@click.option(
    "--min-iterations",
    type=int,
    default=1,
    help="Minimum iterations before checking convergence (default: 1)",
)
@click.pass_context
def check_convergence(
    ctx: click.Context,
    predictions: Path,
    human_labels: Path,
    metric: str,
    threshold: float,
    min_iterations: int,
) -> None:
    r"""Check if model has converged to human agreement level.

    Compares model predictions with human labels using inter-annotator
    agreement metrics, via the ConvergenceDetector from bead.evaluation,
    to determine convergence.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    predictions : Path
        Path to model predictions file.
    human_labels : Path
        Path to human labels file.
    metric : str
        Agreement metric name.
    threshold : float
        Convergence threshold.
    min_iterations : int
        Minimum iterations before allowing convergence.

    Examples
    --------
    $ bead active-learning check-convergence \\
        --predictions predictions.jsonl \\
        --human-labels labels.jsonl \\
        --metric krippendorff_alpha \\
        --threshold 0.85

    $ bead active-learning check-convergence \\
        --predictions predictions.jsonl \\
        --human-labels labels.jsonl \\
        --metric fleiss_kappa \\
        --threshold 0.75
    """
    try:
        console.rule("[bold]Convergence Check[/bold]")

        # Load predictions
        print_info(f"Loading predictions from {predictions}")
        with open(predictions, encoding="utf-8") as f:
            pred_records = [json.loads(line) for line in f if line.strip()]

        model_predictions = [r["prediction"] for r in pred_records]
        print_success(f"Loaded {len(model_predictions)} predictions")

        # Load human labels (organized by rater)
        print_info(f"Loading human labels from {human_labels}")
        with open(human_labels, encoding="utf-8") as f:
            label_records = [json.loads(line) for line in f if line.strip()]

        # Organize by rater
        rater_labels: dict[str, list[int | str | float]] = {}
        for record in label_records:
            rater_id = str(record.get("rater_id", "rater_1"))
            label = record["label"]
            if rater_id not in rater_labels:
                rater_labels[rater_id] = []
            rater_labels[rater_id].append(label)

        n_raters = len(rater_labels)
        print_success(f"Loaded labels from {n_raters} raters")

        # Create convergence detector
        print_info(f"Computing {metric}...")
        detector = ConvergenceDetector(
            human_agreement_metric=metric,
            convergence_threshold=threshold,
            min_iterations=min_iterations,
            statistical_test=True,
        )

        # Compute human baseline
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Computing human agreement baseline...", total=None)
            human_baseline = detector.compute_human_baseline(rater_labels)

        print_success(f"Human baseline: {human_baseline:.4f}")

        # Add model as another "rater" for comparison
        all_raters = {**rater_labels, "model": model_predictions}

        # Compute agreement including model
        if metric == "krippendorff_alpha":
            model_agreement = InterAnnotatorMetrics.krippendorff_alpha(
                all_raters, metric="nominal"
            )
        else:
            # For other metrics, compare model directly to human majority vote
            # Get majority human label for each item
            n_items = len(model_predictions)
            human_votes = []
            for i in range(n_items):
                votes_for_item = [rater_labels[r][i] for r in rater_labels]
                # Simple majority vote
                majority = max(set(votes_for_item), key=votes_for_item.count)
                human_votes.append(majority)

            # Compute agreement between model and human majority
            model_human_pairs = zip(model_predictions, human_votes, strict=True)
            agreements = sum(p == h for p, h in model_human_pairs)
            model_agreement = agreements / len(model_predictions)

        print_success(f"Model agreement: {model_agreement:.4f}")

        # Check convergence
        converged = detector.check_convergence(
            model_accuracy=model_agreement, iteration=min_iterations
        )

        # Display results
        table = Table(title="Convergence Results")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green", justify="right")

        table.add_row("Agreement Metric", metric)
        table.add_row("Human Baseline", f"{human_baseline:.4f}")
        table.add_row("Model Agreement", f"{model_agreement:.4f}")
        table.add_row("Threshold", f"{threshold:.4f}")
        table.add_row("Converged", "✓ Yes" if converged else "✗ No")

        if converged:
            table.add_row(
                "Status", "[green]Model has converged to human agreement[/green]"
            )
        else:
            gap = threshold - model_agreement
            table.add_row(
                "Status", f"[yellow]Need {gap:.4f} more to reach threshold[/yellow]"
            )

        console.print(table)

        # Exit with appropriate code
        if converged:
            print_success("Convergence achieved!")
            ctx.exit(0)
        else:
            print_info("Not yet converged. Continue training.")
            ctx.exit(1)

    except click.exceptions.Exit:
        # ctx.exit() raises click.exceptions.Exit, a RuntimeError subclass;
        # re-raise it so the generic handler below cannot swallow the intended
        # exit code and misreport a successful convergence check as a failure.
        raise
    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except KeyError as e:
        print_error(f"Missing required field in data: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON: {e}")
        ctx.exit(1)
    except Exception as e:
        print_error(f"Convergence check failed: {e}")
        ctx.exit(1)
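
# Editorial note (not part of the wheel): a sketch of the input shapes the
# loaders above expect, inferred from the option help strings and the loading
# loops. Values and label vocabulary are illustrative.
#
#   predictions.jsonl -- one JSON object per line:
#     {"prediction": "acceptable"}
#     {"prediction": "unacceptable"}
#
#   labels.jsonl -- one JSON object per line; records missing "rater_id" are
#   pooled under the default "rater_1". Each rater must label every item in
#   the same item order, because the majority-vote branch indexes each
#   rater's list by position and zips with strict=True:
#     {"rater_id": "r1", "label": "acceptable"}
#     {"rater_id": "r1", "label": "unacceptable"}
#     {"rater_id": "r2", "label": "acceptable"}
#     {"rater_id": "r2", "label": "unacceptable"}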


@click.command()
@click.option(
    "--checkpoint-dir",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Directory containing model checkpoints",
)
@click.option(
    "--human-labels",
    type=click.Path(exists=True, path_type=Path),
    required=True,
    help="Path to human labels file (JSONL with 'label' field per rater)",
)
@click.option(
    "--metric",
    type=click.Choice(
        ["krippendorff_alpha", "fleiss_kappa", "cohens_kappa", "percentage_agreement"],
        case_sensitive=False,
    ),
    default="krippendorff_alpha",
    help="Agreement metric to use (default: krippendorff_alpha)",
)
@click.option(
    "--threshold",
    type=float,
    default=0.80,
    help="Convergence threshold (default: 0.80)",
)
@click.option(
    "--min-iterations",
    type=int,
    default=1,
    help="Minimum iterations before checking convergence (default: 1)",
)
@click.option(
    "--output",
    "-o",
    type=click.Path(path_type=Path),
    default=None,
    help="Output file for convergence report (default: stdout)",
)
@click.pass_context
def monitor_convergence(
    ctx: click.Context,
    checkpoint_dir: Path,
    human_labels: Path,
    metric: str,
    threshold: float,
    min_iterations: int,
    output: Path | None,
) -> None:
    r"""Monitor convergence over multiple iterations.

    Loads model checkpoints from a directory and checks convergence
    against human labels for each iteration. Produces a convergence
    report showing progress over time.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    checkpoint_dir : Path
        Directory containing model checkpoints.
    human_labels : Path
        Path to human labels file.
    metric : str
        Agreement metric name.
    threshold : float
        Convergence threshold.
    min_iterations : int
        Minimum iterations before allowing convergence.
    output : Path | None
        Output file path (None for stdout).

    Examples
    --------
    $ bead active-learning monitor-convergence \\
        --checkpoint-dir models/checkpoints \\
        --human-labels labels.jsonl \\
        --metric krippendorff_alpha \\
        --threshold 0.85

    $ bead active-learning monitor-convergence \\
        --checkpoint-dir models/checkpoints \\
        --human-labels labels.jsonl \\
        --output convergence_report.json
    """
    try:
        console.rule("[bold]Convergence Monitoring[/bold]")

        # Load human labels
        print_info(f"Loading human labels from {human_labels}")
        with open(human_labels, encoding="utf-8") as f:
            label_records = [json.loads(line) for line in f if line.strip()]

        # Organize by rater
        rater_labels: dict[str, list[int | str | float]] = {}
        for record in label_records:
            rater_id = str(record.get("rater_id", "rater_1"))
            label = record["label"]
            if rater_id not in rater_labels:
                rater_labels[rater_id] = []
            rater_labels[rater_id].append(label)

        n_raters = len(rater_labels)
        print_success(f"Loaded labels from {n_raters} raters")

        # Create convergence detector
        detector = ConvergenceDetector(
            human_agreement_metric=metric,
            convergence_threshold=threshold,
            min_iterations=min_iterations,
            statistical_test=True,
        )

        # Compute human baseline
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Computing human agreement baseline...", total=None)
            human_baseline = detector.compute_human_baseline(rater_labels)

        print_success(f"Human baseline: {human_baseline:.4f}")

        # Find checkpoint files
        checkpoint_files = sorted(checkpoint_dir.glob("**/predictions*.jsonl"))
        if not checkpoint_files:
            print_error(f"No prediction files found in {checkpoint_dir}")
            ctx.exit(1)

        print_info(f"Found {len(checkpoint_files)} checkpoint(s)")

        # Process each checkpoint
        convergence_history: list[dict[str, str | int | float | bool]] = []
        for checkpoint_file in checkpoint_files:
            iteration_num = _extract_iteration_number(checkpoint_file)
            if iteration_num is None:
                continue

            print_info(f"Processing iteration {iteration_num}...")

            # Load predictions
            with open(checkpoint_file, encoding="utf-8") as f:
                pred_records = [json.loads(line) for line in f if line.strip()]

            model_predictions = [r["prediction"] for r in pred_records]

            # Compute model agreement
            all_raters = {**rater_labels, "model": model_predictions}
            if metric == "krippendorff_alpha":
                model_agreement = InterAnnotatorMetrics.krippendorff_alpha(
                    all_raters, metric="nominal"
                )
            else:
                # For other metrics, compare to human majority
                n_items = len(model_predictions)
                human_votes = []
                for i in range(n_items):
                    votes_for_item = [rater_labels[r][i] for r in rater_labels]
                    majority = max(set(votes_for_item), key=votes_for_item.count)
                    human_votes.append(majority)

                agreements = sum(
                    p == h for p, h in zip(model_predictions, human_votes, strict=True)
                )
                model_agreement = agreements / len(model_predictions)

            # Check convergence
            converged = detector.check_convergence(
                model_accuracy=model_agreement, iteration=iteration_num
            )

            convergence_history.append(
                {
                    "iteration": iteration_num,
                    "model_agreement": model_agreement,
                    "human_baseline": human_baseline,
                    "converged": converged,
                    "gap": human_baseline - model_agreement,
                }
            )

        # Display results
        table = Table(title="Convergence History")
        table.add_column("Iteration", style="cyan")
        table.add_column("Model Agreement", style="green", justify="right")
        table.add_column("Human Baseline", style="blue", justify="right")
        table.add_column("Gap", style="yellow", justify="right")
        table.add_column("Status", style="magenta")

        for record in convergence_history:
            status = "✓ Converged" if record["converged"] else "✗ Not converged"
            table.add_row(
                str(record["iteration"]),
                f"{record['model_agreement']:.4f}",
                f"{record['human_baseline']:.4f}",
                f"{record['gap']:.4f}",
                status,
            )

        console.print(table)

        # Write output if specified
        if output:
            with open(output, "w", encoding="utf-8") as f:
                json.dump(convergence_history, f, indent=2)
            print_success(f"Report written to {output}")

        # Check if any iteration converged
        any_converged = any(r["converged"] for r in convergence_history)
        if any_converged:
            print_success("Convergence achieved in at least one iteration!")
            ctx.exit(0)
        else:
            print_info("No convergence detected yet. Continue training.")
            ctx.exit(1)

    except click.exceptions.Exit:
        # Re-raise so ctx.exit() is not swallowed by the generic handler below.
        raise
    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except KeyError as e:
        print_error(f"Missing required field in data: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON: {e}")
        ctx.exit(1)
    except Exception as e:
        print_error(f"Convergence monitoring failed: {e}")
        ctx.exit(1)
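
# Editorial note (not part of the wheel): with --output, the report written
# above is json.dump of convergence_history -- a JSON array like (values
# illustrative):
#
#   [{"iteration": 1, "model_agreement": 0.71, "human_baseline": 0.84,
#     "converged": false, "gap": 0.13},
#    {"iteration": 2, "model_agreement": 0.86, "human_baseline": 0.84,
#     "converged": true, "gap": -0.02}]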


def _extract_iteration_number(path: Path) -> int | None:
    """Extract iteration number from checkpoint file path.

    Parameters
    ----------
    path : Path
        Checkpoint file path.

    Returns
    -------
    int | None
        Iteration number if found, None otherwise.
    """
    # Try to find an explicit iteration marker in the filename
    match = re.search(r"iteration[_-]?(\d+)", path.stem, re.IGNORECASE)
    if match:
        return int(match.group(1))

    # Otherwise fall back to the first digit run anywhere in the stem
    match = re.search(r"(\d+)", path.stem)
    if match:
        return int(match.group(1))

    return None
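
# Editorial note (not part of the wheel): how the glob "**/predictions*.jsonl"
# in monitor-convergence combines with the helper above. The fallback takes
# the FIRST digit run in the stem, which can surprise:
#
#   predictions_iteration_3.jsonl  -> 3     (explicit marker)
#   predictions-7.jsonl            -> 7     (fallback)
#   predictions_run2_10.jsonl      -> 2     (fallback grabs "2", not "10")
#   predictions.jsonl              -> None  (checkpoint silently skipped)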


# Register commands
active_learning.add_command(check_convergence)
active_learning.add_command(monitor_convergence)
active_learning.add_command(run)
active_learning.add_command(select_items)
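
For context, this module only defines the group and registers its commands; the `bead` entry point presumably attaches the group to the root CLI in bead/cli/main.py, which this hunk does not show. A minimal, hypothetical sketch of that wiring (click derives the dashed command names active-learning and check-convergence from the function names by replacing underscores):

    # wiring_sketch.py -- hypothetical stand-in for bead/cli/main.py
    import click

    from bead.cli.active_learning import active_learning


    @click.group()
    def cli() -> None:
        """Root command group (illustrative only)."""


    # Registered under the name "active-learning" (underscores become dashes).
    cli.add_command(active_learning)


    if __name__ == "__main__":
        cli()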