bead-0.1.0-py3-none-any.whl
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/cli/training.py
ADDED
@@ -0,0 +1,1080 @@
"""Training commands for bead CLI.

This module provides commands for collecting data, training judgment prediction
models, and evaluating model performance (Stage 6 of the bead pipeline).
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import cast

import click
import numpy as np
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, track
from rich.table import Table
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import KFold

from bead.cli.models import _import_class  # type: ignore[attr-defined]
from bead.cli.utils import print_error, print_info, print_success
from bead.data.base import JsonValue
from bead.data.serialization import read_jsonlines
from bead.data_collection.jatos import JATOSDataCollector
from bead.evaluation.interannotator import InterAnnotatorMetrics
from bead.items.item import Item

console = Console()


@click.group()
def training() -> None:
    r"""Training commands (Stage 6).

    Commands for collecting data and training judgment prediction models.

    \b
    Examples:
        $ bead training collect-data results.jsonl \\
            --jatos-url https://jatos.example.com \\
            --api-token TOKEN --study-id 123
        $ bead training show-data-stats results.jsonl
    """


@click.command()
@click.argument("output_file", type=click.Path(path_type=Path))
@click.option("--jatos-url", required=True, help="JATOS server URL")
@click.option("--api-token", required=True, help="JATOS API token")
@click.option("--study-id", required=True, type=int, help="JATOS study ID")
@click.option("--component-id", type=int, help="Filter by component ID (optional)")
@click.option("--worker-type", help="Filter by worker type (optional)")
@click.pass_context
def collect_data(
    ctx: click.Context,
    output_file: Path,
    jatos_url: str,
    api_token: str,
    study_id: int,
    component_id: int | None,
    worker_type: str | None,
) -> None:
    r"""Collect judgment data from JATOS.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    output_file : Path
        Output path for collected data.
    jatos_url : str
        JATOS server URL.
    api_token : str
        JATOS API token.
    study_id : int
        JATOS study ID.
    component_id : int | None
        Component ID to filter by.
    worker_type : str | None
        Worker type to filter by.

    Examples
    --------
    $ bead training collect-data results.jsonl \\
        --jatos-url https://jatos.example.com \\
        --api-token my-token \\
        --study-id 123

    $ bead training collect-data results.jsonl \\
        --jatos-url https://jatos.example.com \\
        --api-token my-token \\
        --study-id 123 \\
        --component-id 456 \\
        --worker-type Prolific
    """
    try:
        print_info(f"Collecting data from JATOS study {study_id}")

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Downloading results from JATOS...", total=None)

            collector = JATOSDataCollector(
                base_url=jatos_url,
                api_token=api_token,
                study_id=study_id,
            )

            results = collector.download_results(
                output_path=output_file,
                component_id=component_id,
                worker_type=worker_type,
            )

        print_success(f"Collected {len(results)} results: {output_file}")

    except Exception as e:
        print_error(f"Failed to collect data: {e}")
        ctx.exit(1)


@click.command()
@click.argument("data_file", type=click.Path(exists=True, path_type=Path))
@click.pass_context
def show_data_stats(ctx: click.Context, data_file: Path) -> None:
    """Show statistics about collected data.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    data_file : Path
        Path to data file.

    Examples
    --------
    $ bead training show-data-stats results.jsonl
    """
    try:
        print_info(f"Analyzing data: {data_file}")

        # Load and analyze data
        results: list[dict[str, JsonValue]] = []
        with open(data_file, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                result: dict[str, JsonValue] = json.loads(line)
                results.append(result)

        if not results:
            print_error("No data found in file")
            ctx.exit(1)

        # Calculate statistics
        total_results = len(results)

        # Count unique workers if available
        worker_ids: set[str] = set()
        for result in results:
            if "worker_id" in result and isinstance(result["worker_id"], str):
                worker_ids.add(result["worker_id"])

        # Count response types if available
        response_types: dict[str, int] = {}
        for result in results:
            if "data" in result:
                data: JsonValue = result["data"]
                if isinstance(data, dict):
                    for key in data.keys():  # type: ignore[var-annotated]
                        key_str = str(key)  # type: ignore[arg-type]
                        response_types[key_str] = response_types.get(key_str, 0) + 1

        # Display statistics
        table = Table(title="Data Statistics")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green", justify="right")

        table.add_row("Total Results", str(total_results))
        if worker_ids:
            table.add_row("Unique Workers", str(len(worker_ids)))

        if response_types:
            table.add_row("", "")  # Separator
            for resp_type, count in sorted(response_types.items()):
                table.add_row(f"Response Type: {resp_type}", str(count))

        console.print(table)

    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON in data file: {e}")
        ctx.exit(1)
    except Exception as e:
        print_error(f"Failed to show statistics: {e}")
        ctx.exit(1)
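
For reference, show-data-stats keys its worker and response-type counts off two optional fields per record: "worker_id" (a string) and "data" (an object whose keys are tallied). A minimal sketch of a compatible results file, with hypothetical field values (not taken from the package):

# Sketch only: writes two hypothetical records in the shape show-data-stats reads.
import json

records = [
    {"worker_id": "w1", "data": {"rating": 4}},
    {"worker_id": "w2", "data": {"rating": 2, "rt_ms": 1830}},
]
with open("results.jsonl", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")
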
@click.command()
@click.option(
    "--model-dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    required=True,
    help="Directory containing trained model",
)
@click.option(
    "--test-items",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    required=True,
    help="Path to test items (JSONL)",
)
@click.option(
    "--test-labels",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    required=True,
    help="Path to test labels (JSONL, one label per line)",
)
@click.option(
    "--participant-ids",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    help="Path to participant IDs (JSONL, one ID per line, optional)",
)
@click.option(
    "--metrics",
    default="accuracy,precision,recall,f1",
    help="Comma-separated list of metrics (accuracy,precision,recall,f1)",
)
@click.option(
    "--average",
    type=click.Choice(["macro", "micro", "weighted"]),
    default="macro",
    help="Averaging strategy for multi-class metrics",
)
@click.option(
    "--output",
    type=click.Path(path_type=Path),
    help="Output path for evaluation report (JSON)",
)
@click.pass_context
def evaluate(
    ctx: click.Context,
    model_dir: Path,
    test_items: Path,
    test_labels: Path,
    participant_ids: Path | None,
    metrics: str,
    average: str,
    output: Path | None,
) -> None:
    r"""Evaluate trained model on test set.

    Loads a trained model and computes evaluation metrics (accuracy, precision,
    recall, F1) on a held-out test set.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    model_dir : Path
        Directory containing trained model.
    test_items : Path
        Path to test items (JSONL).
    test_labels : Path
        Path to test labels (JSONL, one label per line).
    participant_ids : Path | None
        Path to participant IDs (optional, for random effects models).
    metrics : str
        Comma-separated list of metrics to compute.
    average : str
        Averaging strategy for multi-class metrics.
    output : Path | None
        Output path for evaluation report (JSON).

    Examples
    --------
    $ bead training evaluate \\
        --model-dir models/my_model/ \\
        --test-items data/test_items.jsonl \\
        --test-labels data/test_labels.jsonl \\
        --metrics accuracy,f1 \\
        --output evaluation_report.json
    """
    try:
        print_info(f"Evaluating model: {model_dir}")

        # Load model config
        config_path = model_dir / "config.json"
        if not config_path.exists():
            print_error(f"Model config not found: {config_path}")
            ctx.exit(1)

        with open(config_path, encoding="utf-8") as f:
            model_config = json.load(f)

        task_type = model_config.get("task_type")
        if not task_type:
            print_error("Model config missing 'task_type' field")
            ctx.exit(1)

        # Load test items
        items_list = read_jsonlines(test_items, Item)
        print_info(f"Loaded {len(items_list)} test items")

        # Load test labels
        with open(test_labels, encoding="utf-8") as f:
            labels: list[str | int | float] = [
                json.loads(line.strip()) for line in f if line.strip()
            ]

        if len(items_list) != len(labels):
            print_error(f"Mismatch: {len(items_list)} items but {len(labels)} labels")
            ctx.exit(1)

        # Load participant IDs if provided
        participant_ids_list: list[str] | None = None
        if participant_ids:
            with open(participant_ids, encoding="utf-8") as f:
                participant_ids_list = [
                    json.loads(line.strip()) for line in f if line.strip()
                ]
            if len(participant_ids_list) != len(items_list):
                print_error(
                    f"Mismatch: {len(items_list)} items "
                    f"but {len(participant_ids_list)} participant IDs"
                )
                ctx.exit(1)

        # Load model
        model_class_name = f"{task_type.title().replace('_', '')}Model"
        model_module = f"bead.active_learning.models.{task_type}"
        model_class = _import_class(f"{model_module}.{model_class_name}")

        model_instance = model_class.load(model_dir)
        print_success(f"Loaded model from {model_dir}")

        # Make predictions
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Making predictions...", total=None)
            predictions = model_instance.predict(items_list, participant_ids_list)

        # Compute requested metrics
        metrics_list = [m.strip().lower() for m in metrics.split(",")]
        results: dict[str, float] = {}

        for metric_name in metrics_list:
            if metric_name == "accuracy":
                acc = accuracy_score(labels, predictions)
                results["accuracy"] = acc
            elif metric_name in ["precision", "recall", "f1"]:
                precision, recall, f1, support = precision_recall_fscore_support(
                    labels, predictions, average=average, zero_division=0.0
                )
                if "precision" not in results:
                    results["precision"] = float(precision)
                    results["recall"] = float(recall)
                    results["f1"] = float(f1)
                    # support is None when using averaging
                    if support is not None:
                        results["support"] = (
                            float(support)
                            if isinstance(support, int | float)
                            else float(sum(support))
                        )
            else:
                print_error(f"Unknown metric: {metric_name}")
                ctx.exit(1)

        # Display results
        table = Table(title="Evaluation Results")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green", justify="right")

        for metric_name, value in results.items():
            if metric_name == "support":
                table.add_row(metric_name.capitalize(), f"{int(value)}")
            else:
                table.add_row(metric_name.capitalize(), f"{value:.4f}")

        console.print(table)

        # Save to file if requested
        if output:
            output.parent.mkdir(parents=True, exist_ok=True)
            with open(output, "w", encoding="utf-8") as f:
                json.dump(
                    {
                        "model_dir": str(model_dir),
                        "test_items": str(test_items),
                        "test_labels": str(test_labels),
                        "metrics": results,
                        "average": average,
                    },
                    f,
                    indent=2,
                )
            print_success(f"Evaluation report saved: {output}")

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON: {e}")
        ctx.exit(1)
    except ValueError as e:
        print_error(f"Validation error: {e}")
        ctx.exit(1)
    except ImportError as e:
        print_error(f"Failed to import model class: {e}")
        ctx.exit(1)
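
Note that evaluate (and the two commands below) resolves model classes by naming convention rather than through a registry: the config's "task_type" string is turned into a class name and module path. A short illustration of that transformation, assuming a hypothetical task_type of "forced_choice" (which matches the forced_choice.py module in the manifest above):

# Sketch only: the string manipulation used for the dynamic class lookup.
task_type = "forced_choice"  # hypothetical config value
model_class_name = f"{task_type.title().replace('_', '')}Model"
model_module = f"bead.active_learning.models.{task_type}"
print(f"{model_module}.{model_class_name}")
# -> bead.active_learning.models.forced_choice.ForcedChoiceModel
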
@click.command()
@click.option(
    "--items",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    required=True,
    help="Path to items (JSONL)",
)
@click.option(
    "--labels",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    required=True,
    help="Path to labels (JSONL, one label per line)",
)
@click.option(
    "--participant-ids",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    help="Path to participant IDs (JSONL, optional)",
)
@click.option(
    "--model-config",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    required=True,
    help="Path to model configuration (JSON/YAML)",
)
@click.option(
    "--k-folds",
    type=int,
    default=5,
    help="Number of folds for cross-validation",
)
@click.option(
    "--stratify-by",
    type=click.Choice(["participant_id", "label", "none"]),
    default="none",
    help="Stratification strategy",
)
@click.option(
    "--random-seed",
    type=int,
    help="Random seed for reproducibility",
)
@click.option(
    "--output",
    type=click.Path(path_type=Path),
    help="Output path for CV results (JSON)",
)
@click.pass_context
def cross_validate(
    ctx: click.Context,
    items: Path,
    labels: Path,
    participant_ids: Path | None,
    model_config: Path,
    k_folds: int,
    stratify_by: str,
    random_seed: int | None,
    output: Path | None,
) -> None:
    r"""Perform K-fold cross-validation.

    Trains model with K-fold cross-validation and reports metrics for each fold.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    items : Path
        Path to items (JSONL).
    labels : Path
        Path to labels (JSONL).
    participant_ids : Path | None
        Path to participant IDs (optional).
    model_config : Path
        Path to model configuration file.
    k_folds : int
        Number of folds.
    stratify_by : str
        Stratification strategy.
    random_seed : int | None
        Random seed for reproducibility.
    output : Path | None
        Output path for results (JSON).

    Examples
    --------
    $ bead training cross-validate \\
        --items data/items.jsonl \\
        --labels data/labels.jsonl \\
        --model-config config.yaml \\
        --k-folds 5 \\
        --stratify-by label \\
        --output cv_results.json
    """
    try:
        print_info(f"Running {k_folds}-fold cross-validation")

        # Load items
        items_list = read_jsonlines(items, Item)
        print_info(f"Loaded {len(items_list)} items")

        # Load labels
        with open(labels, encoding="utf-8") as f:
            labels_list: list[JsonValue] = [
                json.loads(line.strip()) for line in f if line.strip()
            ]

        if len(items_list) != len(labels_list):
            print_error(
                f"Mismatch: {len(items_list)} items but {len(labels_list)} labels"
            )
            ctx.exit(1)

        # Load participant IDs if provided
        participant_ids_list: list[str] | None = None
        if participant_ids:
            with open(participant_ids, encoding="utf-8") as f:
                participant_ids_list = [
                    json.loads(line.strip()) for line in f if line.strip()
                ]
            if len(participant_ids_list) != len(items_list):
                print_error(
                    f"Mismatch: {len(items_list)} items "
                    f"but {len(participant_ids_list)} participant IDs"
                )
                ctx.exit(1)

        # Load model config
        with open(model_config, encoding="utf-8") as f:
            config_dict = json.load(f)

        task_type = config_dict.get("task_type")
        if not task_type:
            print_error("Model config missing 'task_type' field")
            ctx.exit(1)

        # Import model and config classes
        model_class_name = f"{task_type.title().replace('_', '')}Model"
        config_class_name = f"{task_type.title().replace('_', '')}ModelConfig"
        model_module = f"bead.active_learning.models.{task_type}"
        config_module = "bead.config.active_learning"

        model_class = _import_class(f"{model_module}.{model_class_name}")
        config_class = _import_class(f"{config_module}.{config_class_name}")

        # Create cross-validator
        cv = KFold(n_splits=k_folds, shuffle=True, random_state=random_seed)

        # Generate fold indices
        fold_indices = list(cv.split(items_list))

        print_info(f"Generated {len(fold_indices)} folds")

        # Train and evaluate on each fold
        fold_results: list[dict[str, float | int]] = []

        for fold_idx, (train_indices, test_indices) in enumerate(fold_indices, start=1):
            print_info(f"\n[Fold {fold_idx}/{k_folds}]")
            print_info(f"  Train: {len(train_indices)} items")
            print_info(f"  Test: {len(test_indices)} items")

            # Get items for train and test sets
            train_items = [items_list[i] for i in train_indices]
            test_items = [items_list[i] for i in test_indices]

            # Get labels for this fold
            train_labels = [labels_list[i] for i in train_indices]
            test_labels = [labels_list[i] for i in test_indices]

            # Get participant IDs for this fold (if provided)
            train_pids: list[str] | None = None
            test_pids: list[str] | None = None
            if participant_ids_list is not None:
                train_pids = [participant_ids_list[i] for i in train_indices]
                test_pids = [participant_ids_list[i] for i in test_indices]

            # Create and train model for this fold
            print_info("  Training model...")
            model_config_obj = config_class(**config_dict)
            model_instance = model_class(config=model_config_obj)
            model_instance.train(train_items, train_labels, participant_ids=train_pids)

            # Make predictions on test set
            predictions = model_instance.predict(test_items, participant_ids=test_pids)
            pred_labels = [p.predicted_class for p in predictions]

            # Compute metrics
            accuracy = accuracy_score(test_labels, pred_labels)
            precision, recall, f1, support = precision_recall_fscore_support(
                test_labels, pred_labels, average="macro", zero_division=0.0
            )
            prf: dict[str, float] = {
                "precision": float(precision),
                "recall": float(recall),
                "f1": float(f1),
            }
            # support is None when using averaging
            if support is not None:
                prf["support"] = (
                    float(support)
                    if isinstance(support, int | float)
                    else float(sum(support))
                )

            fold_result: dict[str, float | int] = {
                "fold": fold_idx,
                "accuracy": float(accuracy),
                "precision": prf["precision"],
                "recall": prf["recall"],
                "f1": prf["f1"],
            }
            if "support" in prf:
                fold_result["support"] = prf["support"]
            fold_results.append(fold_result)

            print_success(f"  Accuracy: {accuracy:.4f}, F1: {prf['f1']:.4f}")

        # Compute average metrics
        avg_results = {
            "accuracy": np.mean([r["accuracy"] for r in fold_results]),
            "precision": np.mean([r["precision"] for r in fold_results]),
            "recall": np.mean([r["recall"] for r in fold_results]),
            "f1": np.mean([r["f1"] for r in fold_results]),
        }

        # Display summary
        console.rule("[bold]Cross-Validation Summary[/bold]")
        table = Table()
        table.add_column("Metric", style="cyan")
        table.add_column("Mean", style="green", justify="right")
        table.add_column("Std", style="yellow", justify="right")

        for metric_name in ["accuracy", "precision", "recall", "f1"]:
            values = [r[metric_name] for r in fold_results]
            mean_val = np.mean(values)
            std_val = np.std(values)
            table.add_row(metric_name.capitalize(), f"{mean_val:.4f}", f"{std_val:.4f}")

        console.print(table)

        # Save results
        if output:
            output.parent.mkdir(parents=True, exist_ok=True)
            with open(output, "w", encoding="utf-8") as f:
                json.dump(
                    {
                        "k_folds": k_folds,
                        "stratify_by": stratify_by,
                        "fold_results": fold_results,
                        "average_metrics": avg_results,
                    },
                    f,
                    indent=2,
                )
            print_success(f"CV results saved: {output}")

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON: {e}")
        ctx.exit(1)
    except ValueError as e:
        print_error(f"Validation error: {e}")
        ctx.exit(1)
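
As written, the fold generation above uses plain KFold for every setting of --stratify-by; the flag is only echoed into the saved report. For comparison, a label-stratified split with scikit-learn would look like the sketch below — an illustration of the general technique with hypothetical data, not the package's behavior:

# Sketch only: label-stratified folds, each preserving class proportions.
from sklearn.model_selection import StratifiedKFold

labels_list = ["yes", "no", "yes", "no", "yes", "no"]  # hypothetical labels
items_list = list(range(len(labels_list)))             # stand-ins for Item objects

cv = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
for train_idx, test_idx in cv.split(items_list, labels_list):
    print(train_idx, test_idx)  # each test fold holds a balanced yes/no mix
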
@click.command()
@click.option(
    "--items",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    required=True,
    help="Path to items (JSONL)",
)
@click.option(
    "--labels",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    required=True,
    help="Path to labels (JSONL)",
)
@click.option(
    "--model-config",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    required=True,
    help="Path to model configuration",
)
@click.option(
    "--train-sizes",
    default="0.1,0.2,0.5,0.8,1.0",
    help="Comma-separated training set sizes (fractions)",
)
@click.option(
    "--random-seed",
    type=int,
    help="Random seed for reproducibility",
)
@click.option(
    "--output",
    type=click.Path(path_type=Path),
    help="Output path for learning curve data (JSON)",
)
@click.pass_context
def learning_curve(
    ctx: click.Context,
    items: Path,
    labels: Path,
    model_config: Path,
    train_sizes: str,
    random_seed: int | None,
    output: Path | None,
) -> None:
    r"""Generate learning curve with varying training set sizes.

    Trains models with increasing amounts of training data and plots
    training/validation performance.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    items : Path
        Path to items (JSONL).
    labels : Path
        Path to labels (JSONL).
    model_config : Path
        Path to model configuration.
    train_sizes : str
        Comma-separated training set sizes (fractions).
    random_seed : int | None
        Random seed for reproducibility.
    output : Path | None
        Output path for results (JSON).

    Examples
    --------
    $ bead training learning-curve \\
        --items data/items.jsonl \\
        --labels data/labels.jsonl \\
        --model-config config.yaml \\
        --train-sizes 0.1,0.2,0.5,1.0 \\
        --output learning_curve.json
    """
    try:
        print_info("Generating learning curve")

        # Load items
        items_list = read_jsonlines(items, Item)
        print_info(f"Loaded {len(items_list)} items")

        # Load labels
        with open(labels, encoding="utf-8") as f:
            labels_list: list[str | int | float] = [
                json.loads(line.strip()) for line in f if line.strip()
            ]

        # Load model config
        with open(model_config, encoding="utf-8") as f:
            config_dict = json.load(f)

        task_type = config_dict.get("task_type")
        if not task_type:
            print_error("Model config missing 'task_type' field")
            ctx.exit(1)

        # Import model and config classes
        model_class_name = f"{task_type.title().replace('_', '')}Model"
        config_class_name = f"{task_type.title().replace('_', '')}ModelConfig"
        model_module = f"bead.active_learning.models.{task_type}"
        config_module = "bead.config.active_learning"

        model_class = _import_class(f"{model_module}.{model_class_name}")
        config_class = _import_class(f"{config_module}.{config_class_name}")

        # Parse train sizes
        sizes = [float(s.strip()) for s in train_sizes.split(",")]
        if any(s <= 0 or s > 1 for s in sizes):
            print_error("Train sizes must be in range (0, 1]")
            ctx.exit(1)

        # Train with different data sizes
        curve_results: list[dict[str, float]] = []

        for size in track(sizes, description="Training with varying data sizes"):
            n_samples = int(len(items_list) * size)
            print_info(f"\nTraining with {n_samples} samples ({size:.0%})")

            # Split into train/test (80/20)
            split_idx = int(n_samples * 0.8)
            train_items_subset = items_list[:split_idx]
            test_items_subset = items_list[split_idx:n_samples]
            train_labels_subset = labels_list[:split_idx]
            test_labels_subset = labels_list[split_idx:n_samples]

            # Train model
            print_info("  Training...")
            model_config_obj = config_class(**config_dict)
            model_instance = model_class(config=model_config_obj)
            # Note: participant_ids=None for fixed effects models
            model_instance.train(
                train_items_subset, train_labels_subset, participant_ids=None
            )

            # Make predictions
            train_predictions = model_instance.predict(
                train_items_subset, participant_ids=None
            )
            test_predictions = model_instance.predict(
                test_items_subset, participant_ids=None
            )

            # Compute metrics
            train_acc = accuracy_score(train_labels_subset, train_predictions)
            test_acc = accuracy_score(test_labels_subset, test_predictions)

            curve_results.append(
                {
                    "train_size": size,
                    "n_samples": n_samples,
                    "train_accuracy": train_acc,
                    "test_accuracy": test_acc,
                }
            )

            print_success(f"  Train acc: {train_acc:.4f}, Test acc: {test_acc:.4f}")

        # Display summary
        console.rule("[bold]Learning Curve Summary[/bold]")
        table = Table()
        table.add_column("Train Size", style="cyan")
        table.add_column("N Samples", style="blue", justify="right")
        table.add_column("Train Acc", style="green", justify="right")
        table.add_column("Test Acc", style="yellow", justify="right")

        for result in curve_results:
            table.add_row(
                f"{result['train_size']:.0%}",
                str(result["n_samples"]),
                f"{result['train_accuracy']:.4f}",
                f"{result['test_accuracy']:.4f}",
            )

        console.print(table)

        # Save results
        if output:
            output.parent.mkdir(parents=True, exist_ok=True)
            with open(output, "w", encoding="utf-8") as f:
                json.dump({"curve_data": curve_results}, f, indent=2)
            print_success(f"Learning curve data saved: {output}")

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON: {e}")
        ctx.exit(1)
    except ValueError as e:
        print_error(f"Validation error: {e}")
        ctx.exit(1)
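
To make the slicing above concrete: for each fraction the command takes the first n_samples items in file order and reserves the last 20% of that prefix as the test set (no shuffling happens in the slicing shown, so --random-seed is accepted but not consulted there). A worked example with hypothetical sizes:

# Sketch only: the learning-curve split arithmetic for 1000 items at size 0.5.
items_total = 1000
size = 0.5
n_samples = int(items_total * size)  # 500
split_idx = int(n_samples * 0.8)     # 400
# train = items[:400], test = items[400:500]
print(n_samples, split_idx)
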
@click.command()
@click.option(
    "--annotations",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    required=True,
    help="Path to annotations (JSONL with 'rater_id' and 'label' fields)",
)
@click.option(
    "--metric",
    type=click.Choice(
        [
            "krippendorff_alpha",
            "fleiss_kappa",
            "cohens_kappa",
            "percentage_agreement",
        ]
    ),
    default="krippendorff_alpha",
    help="Agreement metric to compute",
)
@click.option(
    "--data-type",
    type=click.Choice(["nominal", "ordinal", "interval", "ratio"]),
    default="nominal",
    help="Data type for Krippendorff's alpha",
)
@click.option(
    "--output",
    type=click.Path(path_type=Path),
    help="Output path for agreement report (JSON)",
)
@click.pass_context
def compute_agreement(
    ctx: click.Context,
    annotations: Path,
    metric: str,
    data_type: str,
    output: Path | None,
) -> None:
    r"""Compute inter-annotator agreement.

    Calculates agreement metrics (Cohen's kappa, Fleiss' kappa, Krippendorff's
    alpha, or percentage agreement) from multi-rater annotations.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    annotations : Path
        Path to annotations file (JSONL).
    metric : str
        Agreement metric to compute.
    data_type : str
        Data type for Krippendorff's alpha.
    output : Path | None
        Output path for report (JSON).

    Examples
    --------
    $ bead training compute-agreement \\
        --annotations data/annotations.jsonl \\
        --metric krippendorff_alpha \\
        --data-type nominal \\
        --output agreement_report.json

    $ bead training compute-agreement \\
        --annotations data/annotations.jsonl \\
        --metric cohens_kappa
    """
    try:
        print_info(f"Computing {metric.replace('_', ' ').title()}")

        # Load annotations
        with open(annotations, encoding="utf-8") as f:
            annotation_records = [json.loads(line) for line in f if line.strip()]

        print_info(f"Loaded {len(annotation_records)} annotation records")

        # Organize annotations by rater
        rater_annotations: dict[str, list[str | int | float]] = {}
        for record in annotation_records:
            rater_id = str(record.get("rater_id", "unknown"))
            label = record.get("label")
            if rater_id not in rater_annotations:
                rater_annotations[rater_id] = []
            rater_annotations[rater_id].append(label)

        n_raters = len(rater_annotations)
        print_info(f"Found {n_raters} raters")

        # Compute agreement metric
        agreement_score: float
        if metric == "percentage_agreement":
            if n_raters != 2:
                print_error("Percentage agreement requires exactly 2 raters")
                ctx.exit(1)
            rater_ids = list(rater_annotations.keys())
            agreement_score = InterAnnotatorMetrics.percentage_agreement(
                rater_annotations[rater_ids[0]], rater_annotations[rater_ids[1]]
            )
        elif metric == "cohens_kappa":
            if n_raters != 2:
                print_error("Cohen's kappa requires exactly 2 raters")
                ctx.exit(1)
            rater_ids = list(rater_annotations.keys())
            agreement_score = InterAnnotatorMetrics.cohens_kappa(
                rater_annotations[rater_ids[0]], rater_annotations[rater_ids[1]]
            )
        elif metric == "fleiss_kappa":
            # Convert to ratings matrix format
            # Matrix shape: (n_items, n_categories)
            all_labels = set()
            for labels in rater_annotations.values():
                all_labels.update(labels)
            categories = sorted(all_labels)
            n_items = len(next(iter(rater_annotations.values())))

            ratings_matrix = np.zeros((n_items, len(categories)), dtype=int)
            for labels in rater_annotations.values():
                for item_idx, label in enumerate(labels):
                    cat_idx = categories.index(label)
                    ratings_matrix[item_idx, cat_idx] += 1

            agreement_score = InterAnnotatorMetrics.fleiss_kappa(
                cast(np.ndarray[int, np.dtype[np.int_]], ratings_matrix)  # type: ignore[misc,valid-type]
            )
        elif metric == "krippendorff_alpha":
            agreement_score = InterAnnotatorMetrics.krippendorff_alpha(
                rater_annotations, metric=data_type
            )
        else:
            print_error(f"Unknown metric: {metric}")
            ctx.exit(1)

        # Display result
        table = Table(title="Inter-Annotator Agreement")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green", justify="right")
        table.add_column("Interpretation", style="yellow")

        # Interpretation guidelines (Landis & Koch, 1977)
        if agreement_score < 0:
            interpretation = "Poor"
        elif agreement_score < 0.2:
            interpretation = "Slight"
        elif agreement_score < 0.4:
            interpretation = "Fair"
        elif agreement_score < 0.6:
            interpretation = "Moderate"
        elif agreement_score < 0.8:
            interpretation = "Substantial"
        else:
            interpretation = "Almost Perfect"

        table.add_row(
            metric.replace("_", " ").title(),
            f"{agreement_score:.4f}",
            interpretation,
        )
        table.add_row("N Raters", str(n_raters), "")
        table.add_row("N Items", str(len(annotation_records) // n_raters), "")

        console.print(table)

        # Save results
        if output:
            output.parent.mkdir(parents=True, exist_ok=True)
            with open(output, "w", encoding="utf-8") as f:
                data_type_value = data_type if metric == "krippendorff_alpha" else None
                json.dump(
                    {
                        "metric": metric,
                        "data_type": data_type_value,
                        "score": agreement_score,
                        "interpretation": interpretation,
                        "n_raters": n_raters,
                        "n_items": len(annotation_records) // n_raters,
                    },
                    f,
                    indent=2,
                )
            print_success(f"Agreement report saved: {output}")

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON: {e}")
        ctx.exit(1)
    except ValueError as e:
        print_error(f"Validation error: {e}")
        ctx.exit(1)


# Register commands
training.add_command(collect_data)
training.add_command(show_data_stats)
training.add_command(evaluate)
training.add_command(cross_validate)
training.add_command(learning_curve)
training.add_command(compute_agreement)
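
Finally, a minimal sketch (hypothetical values) of the annotations file that compute-agreement reads. Only "rater_id" and "label" are consulted; labels are grouped per rater in file order, so every rater must cover the same items in the same order, since alignment is purely positional:

# Sketch only: two raters, two items, in the shape compute-agreement expects.
import json

records = [
    {"rater_id": "r1", "label": "yes"},
    {"rater_id": "r1", "label": "no"},
    {"rater_id": "r2", "label": "yes"},
    {"rater_id": "r2", "label": "yes"},
]
with open("annotations.jsonl", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# For fleiss_kappa the command then counts, per item, how many raters chose
# each category: item 1 -> {no: 0, yes: 2}, item 2 -> {no: 1, yes: 1}.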