bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/cli/training.py ADDED
@@ -0,0 +1,1080 @@
+ """Training commands for bead CLI.
+
+ This module provides commands for collecting data, training judgment prediction
+ models, and evaluating model performance (Stage 6 of the bead pipeline).
+ """
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import cast
+
+ import click
+ import numpy as np
+ from rich.console import Console
+ from rich.progress import Progress, SpinnerColumn, TextColumn, track
+ from rich.table import Table
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ from sklearn.model_selection import KFold
+
+ from bead.cli.models import _import_class  # type: ignore[attr-defined]
+ from bead.cli.utils import print_error, print_info, print_success
+ from bead.data.base import JsonValue
+ from bead.data.serialization import read_jsonlines
+ from bead.data_collection.jatos import JATOSDataCollector
+ from bead.evaluation.interannotator import InterAnnotatorMetrics
+ from bead.items.item import Item
+
+ console = Console()
+
+
+ @click.group()
+ def training() -> None:
+     r"""Training commands (Stage 6).
+
+     Commands for collecting data and training judgment prediction models.
+
+     \b
+     Examples:
+         $ bead training collect-data results.jsonl \\
+             --jatos-url https://jatos.example.com \\
+             --api-token TOKEN --study-id 123
+         $ bead training show-data-stats results.jsonl
+     """
+
+
+ @click.command()
+ @click.argument("output_file", type=click.Path(path_type=Path))
+ @click.option("--jatos-url", required=True, help="JATOS server URL")
+ @click.option("--api-token", required=True, help="JATOS API token")
+ @click.option("--study-id", required=True, type=int, help="JATOS study ID")
+ @click.option("--component-id", type=int, help="Filter by component ID (optional)")
+ @click.option("--worker-type", help="Filter by worker type (optional)")
+ @click.pass_context
+ def collect_data(
+     ctx: click.Context,
+     output_file: Path,
+     jatos_url: str,
+     api_token: str,
+     study_id: int,
+     component_id: int | None,
+     worker_type: str | None,
+ ) -> None:
+     r"""Collect judgment data from JATOS.
+
+     Parameters
+     ----------
+     ctx : click.Context
+         Click context object.
+     output_file : Path
+         Output path for collected data.
+     jatos_url : str
+         JATOS server URL.
+     api_token : str
+         JATOS API token.
+     study_id : int
+         JATOS study ID.
+     component_id : int | None
+         Component ID to filter by.
+     worker_type : str | None
+         Worker type to filter by.
+
+     Examples
+     --------
+     $ bead training collect-data results.jsonl \\
+         --jatos-url https://jatos.example.com \\
+         --api-token my-token \\
+         --study-id 123
+
+     $ bead training collect-data results.jsonl \\
+         --jatos-url https://jatos.example.com \\
+         --api-token my-token \\
+         --study-id 123 \\
+         --component-id 456 \\
+         --worker-type Prolific
+     """
+     try:
+         print_info(f"Collecting data from JATOS study {study_id}")
+
+         with Progress(
+             SpinnerColumn(),
+             TextColumn("[progress.description]{task.description}"),
+             console=console,
+         ) as progress:
+             progress.add_task("Downloading results from JATOS...", total=None)
+
+             collector = JATOSDataCollector(
+                 base_url=jatos_url,
+                 api_token=api_token,
+                 study_id=study_id,
+             )
+
+             results = collector.download_results(
+                 output_path=output_file,
+                 component_id=component_id,
+                 worker_type=worker_type,
+             )
+
+         print_success(f"Collected {len(results)} results: {output_file}")
+
+     except Exception as e:
+         print_error(f"Failed to collect data: {e}")
+         ctx.exit(1)
+
+
+ @click.command()
+ @click.argument("data_file", type=click.Path(exists=True, path_type=Path))
+ @click.pass_context
+ def show_data_stats(ctx: click.Context, data_file: Path) -> None:
+     """Show statistics about collected data.
+
+     Parameters
+     ----------
+     ctx : click.Context
+         Click context object.
+     data_file : Path
+         Path to data file.
+
+     Examples
+     --------
+     $ bead training show-data-stats results.jsonl
+     """
+     try:
+         print_info(f"Analyzing data: {data_file}")
+
+         # Load and analyze data
+         results: list[dict[str, JsonValue]] = []
+         with open(data_file, encoding="utf-8") as f:
+             for line in f:
+                 line = line.strip()
+                 if not line:
+                     continue
+                 result: dict[str, JsonValue] = json.loads(line)
+                 results.append(result)
+
+         if not results:
+             print_error("No data found in file")
+             ctx.exit(1)
+
+         # Calculate statistics
+         total_results = len(results)
+
+         # Count unique workers if available
+         worker_ids: set[str] = set()
+         for result in results:
+             if "worker_id" in result and isinstance(result["worker_id"], str):
+                 worker_ids.add(result["worker_id"])
+
+         # Count response types if available
+         response_types: dict[str, int] = {}
+         for result in results:
+             if "data" in result:
+                 data: JsonValue = result["data"]
+                 if isinstance(data, dict):
+                     for key in data.keys():  # type: ignore[var-annotated]
+                         key_str = str(key)  # type: ignore[arg-type]
+                         response_types[key_str] = response_types.get(key_str, 0) + 1
+
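+         # Illustrative record (hypothetical field values):
+         #   {"worker_id": "w1", "data": {"response": "yes", "rt": 512}}
+         # counts one unique worker and increments the "response" and "rt" types.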
+         # Display statistics
+         table = Table(title="Data Statistics")
+         table.add_column("Metric", style="cyan")
+         table.add_column("Value", style="green", justify="right")
+
+         table.add_row("Total Results", str(total_results))
+         if worker_ids:
+             table.add_row("Unique Workers", str(len(worker_ids)))
+
+         if response_types:
+             table.add_row("", "")  # Separator
+             for resp_type, count in sorted(response_types.items()):
+                 table.add_row(f"Response Type: {resp_type}", str(count))
+
+         console.print(table)
+
+     except json.JSONDecodeError as e:
+         print_error(f"Invalid JSON in data file: {e}")
+         ctx.exit(1)
+     except Exception as e:
+         print_error(f"Failed to show statistics: {e}")
+         ctx.exit(1)
+
+
+ @click.command()
+ @click.option(
+     "--model-dir",
+     type=click.Path(exists=True, file_okay=False, path_type=Path),
+     required=True,
+     help="Directory containing trained model",
+ )
+ @click.option(
+     "--test-items",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     required=True,
+     help="Path to test items (JSONL)",
+ )
+ @click.option(
+     "--test-labels",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     required=True,
+     help="Path to test labels (JSONL, one label per line)",
+ )
+ @click.option(
+     "--participant-ids",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     help="Path to participant IDs (JSONL, one ID per line, optional)",
+ )
+ @click.option(
+     "--metrics",
+     default="accuracy,precision,recall,f1",
+     help="Comma-separated list of metrics (accuracy,precision,recall,f1)",
+ )
+ @click.option(
+     "--average",
+     type=click.Choice(["macro", "micro", "weighted"]),
+     default="macro",
+     help="Averaging strategy for multi-class metrics",
+ )
+ @click.option(
+     "--output",
+     type=click.Path(path_type=Path),
+     help="Output path for evaluation report (JSON)",
+ )
+ @click.pass_context
+ def evaluate(
+     ctx: click.Context,
+     model_dir: Path,
+     test_items: Path,
+     test_labels: Path,
+     participant_ids: Path | None,
+     metrics: str,
+     average: str,
+     output: Path | None,
+ ) -> None:
+     r"""Evaluate trained model on test set.
+
+     Loads a trained model and computes evaluation metrics (accuracy, precision,
+     recall, F1) on a held-out test set.
+
+     Parameters
+     ----------
+     ctx : click.Context
+         Click context object.
+     model_dir : Path
+         Directory containing trained model.
+     test_items : Path
+         Path to test items (JSONL).
+     test_labels : Path
+         Path to test labels (JSONL, one label per line).
+     participant_ids : Path | None
+         Path to participant IDs (optional, for random effects models).
+     metrics : str
+         Comma-separated list of metrics to compute.
+     average : str
+         Averaging strategy for multi-class metrics.
+     output : Path | None
+         Output path for evaluation report (JSON).
+
+     Examples
+     --------
+     $ bead training evaluate \\
+         --model-dir models/my_model/ \\
+         --test-items data/test_items.jsonl \\
+         --test-labels data/test_labels.jsonl \\
+         --metrics accuracy,f1 \\
+         --output evaluation_report.json
+     """
+     try:
+         print_info(f"Evaluating model: {model_dir}")
+
+         # Load model config
+         config_path = model_dir / "config.json"
+         if not config_path.exists():
+             print_error(f"Model config not found: {config_path}")
+             ctx.exit(1)
+
+         with open(config_path, encoding="utf-8") as f:
+             model_config = json.load(f)
+
+         task_type = model_config.get("task_type")
+         if not task_type:
+             print_error("Model config missing 'task_type' field")
+             ctx.exit(1)
+
+         # Load test items
+         items_list = read_jsonlines(test_items, Item)
+         print_info(f"Loaded {len(items_list)} test items")
+
+         # Load test labels
+         with open(test_labels, encoding="utf-8") as f:
+             labels: list[str | int | float] = [
+                 json.loads(line.strip()) for line in f if line.strip()
+             ]
+
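+         # Each label line is a bare JSON scalar, e.g. (hypothetical values)
+         # "acceptable", 1, or 0.75, one label per test item, in item order.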
+         if len(items_list) != len(labels):
+             print_error(f"Mismatch: {len(items_list)} items but {len(labels)} labels")
+             ctx.exit(1)
+
+         # Load participant IDs if provided
+         participant_ids_list: list[str] | None = None
+         if participant_ids:
+             with open(participant_ids, encoding="utf-8") as f:
+                 participant_ids_list = [
+                     json.loads(line.strip()) for line in f if line.strip()
+                 ]
+             if len(participant_ids_list) != len(items_list):
+                 print_error(
+                     f"Mismatch: {len(items_list)} items "
+                     f"but {len(participant_ids_list)} participant IDs"
+                 )
+                 ctx.exit(1)
+
+         # Load model
+         model_class_name = f"{task_type.title().replace('_', '')}Model"
+         model_module = f"bead.active_learning.models.{task_type}"
+         model_class = _import_class(f"{model_module}.{model_class_name}")
+
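+         # For example, task_type "forced_choice" maps to class name
+         # "ForcedChoiceModel" in module bead.active_learning.models.forced_choice.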
+         model_instance = model_class.load(model_dir)
+         print_success(f"Loaded model from {model_dir}")
+
+         # Make predictions
+         with Progress(
+             SpinnerColumn(),
+             TextColumn("[progress.description]{task.description}"),
+             console=console,
+         ) as progress:
+             progress.add_task("Making predictions...", total=None)
+             predictions = model_instance.predict(items_list, participant_ids_list)
+
+         # predict() returns prediction objects; extract the class labels
+         # before scoring (matches the handling in cross-validate below)
+         pred_labels = [p.predicted_class for p in predictions]
+
+         # Compute requested metrics
+         metrics_list = [m.strip().lower() for m in metrics.split(",")]
+         results: dict[str, float] = {}
+
+         for metric_name in metrics_list:
+             if metric_name == "accuracy":
+                 acc = accuracy_score(labels, pred_labels)
+                 results["accuracy"] = float(acc)
+             elif metric_name in ["precision", "recall", "f1"]:
+                 precision, recall, f1, support = precision_recall_fscore_support(
+                     labels, pred_labels, average=average, zero_division=0.0
+                 )
+                 if "precision" not in results:
+                     results["precision"] = float(precision)
+                     results["recall"] = float(recall)
+                     results["f1"] = float(f1)
+                     # support is None when using averaging
+                     if support is not None:
+                         results["support"] = (
+                             float(support)
+                             if isinstance(support, int | float)
+                             else float(sum(support))
+                         )
+             else:
+                 print_error(f"Unknown metric: {metric_name}")
+                 ctx.exit(1)
+
+         # Display results
+         table = Table(title="Evaluation Results")
+         table.add_column("Metric", style="cyan")
+         table.add_column("Value", style="green", justify="right")
+
+         for metric_name, value in results.items():
+             if metric_name == "support":
+                 table.add_row(metric_name.capitalize(), f"{int(value)}")
+             else:
+                 table.add_row(metric_name.capitalize(), f"{value:.4f}")
+
+         console.print(table)
+
+         # Save to file if requested
+         if output:
+             output.parent.mkdir(parents=True, exist_ok=True)
+             with open(output, "w", encoding="utf-8") as f:
+                 json.dump(
+                     {
+                         "model_dir": str(model_dir),
+                         "test_items": str(test_items),
+                         "test_labels": str(test_labels),
+                         "metrics": results,
+                         "average": average,
+                     },
+                     f,
+                     indent=2,
+                 )
+             print_success(f"Evaluation report saved: {output}")
+
+     except FileNotFoundError as e:
+         print_error(f"File not found: {e}")
+         ctx.exit(1)
+     except json.JSONDecodeError as e:
+         print_error(f"Invalid JSON: {e}")
+         ctx.exit(1)
+     except ValueError as e:
+         print_error(f"Validation error: {e}")
+         ctx.exit(1)
+     except ImportError as e:
+         print_error(f"Failed to import model class: {e}")
+         ctx.exit(1)
+
+
+ @click.command()
+ @click.option(
+     "--items",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     required=True,
+     help="Path to items (JSONL)",
+ )
+ @click.option(
+     "--labels",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     required=True,
+     help="Path to labels (JSONL, one label per line)",
+ )
+ @click.option(
+     "--participant-ids",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     help="Path to participant IDs (JSONL, optional)",
+ )
+ @click.option(
+     "--model-config",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     required=True,
+     help="Path to model configuration (JSON)",
+ )
+ @click.option(
+     "--k-folds",
+     type=int,
+     default=5,
+     help="Number of folds for cross-validation",
+ )
+ @click.option(
+     "--stratify-by",
+     type=click.Choice(["participant_id", "label", "none"]),
+     default="none",
+     help="Stratification strategy",
+ )
+ @click.option(
+     "--random-seed",
+     type=int,
+     help="Random seed for reproducibility",
+ )
+ @click.option(
+     "--output",
+     type=click.Path(path_type=Path),
+     help="Output path for CV results (JSON)",
+ )
+ @click.pass_context
+ def cross_validate(
+     ctx: click.Context,
+     items: Path,
+     labels: Path,
+     participant_ids: Path | None,
+     model_config: Path,
+     k_folds: int,
+     stratify_by: str,
+     random_seed: int | None,
+     output: Path | None,
+ ) -> None:
+     r"""Perform K-fold cross-validation.
+
+     Trains a model with K-fold cross-validation and reports metrics for each fold.
+
+     Parameters
+     ----------
+     ctx : click.Context
+         Click context object.
+     items : Path
+         Path to items (JSONL).
+     labels : Path
+         Path to labels (JSONL).
+     participant_ids : Path | None
+         Path to participant IDs (optional).
+     model_config : Path
+         Path to model configuration file.
+     k_folds : int
+         Number of folds.
+     stratify_by : str
+         Stratification strategy.
+     random_seed : int | None
+         Random seed for reproducibility.
+     output : Path | None
+         Output path for results (JSON).
+
+     Examples
+     --------
+     $ bead training cross-validate \\
+         --items data/items.jsonl \\
+         --labels data/labels.jsonl \\
+         --model-config config.json \\
+         --k-folds 5 \\
+         --stratify-by label \\
+         --output cv_results.json
+     """
+     try:
+         print_info(f"Running {k_folds}-fold cross-validation")
+
+         # Load items
+         items_list = read_jsonlines(items, Item)
+         print_info(f"Loaded {len(items_list)} items")
+
+         # Load labels
+         with open(labels, encoding="utf-8") as f:
+             labels_list: list[JsonValue] = [
+                 json.loads(line.strip()) for line in f if line.strip()
+             ]
+
+         if len(items_list) != len(labels_list):
+             print_error(
+                 f"Mismatch: {len(items_list)} items but {len(labels_list)} labels"
+             )
+             ctx.exit(1)
+
+         # Load participant IDs if provided
+         participant_ids_list: list[str] | None = None
+         if participant_ids:
+             with open(participant_ids, encoding="utf-8") as f:
+                 participant_ids_list = [
+                     json.loads(line.strip()) for line in f if line.strip()
+                 ]
+             if len(participant_ids_list) != len(items_list):
+                 print_error(
+                     f"Mismatch: {len(items_list)} items "
+                     f"but {len(participant_ids_list)} participant IDs"
+                 )
+                 ctx.exit(1)
+
+         # Load model config (parsed as JSON)
+         with open(model_config, encoding="utf-8") as f:
+             config_dict = json.load(f)
+
+         task_type = config_dict.get("task_type")
+         if not task_type:
+             print_error("Model config missing 'task_type' field")
+             ctx.exit(1)
+
+         # Import model and config classes
+         model_class_name = f"{task_type.title().replace('_', '')}Model"
+         config_class_name = f"{task_type.title().replace('_', '')}ModelConfig"
+         model_module = f"bead.active_learning.models.{task_type}"
+         config_module = "bead.config.active_learning"
+
+         model_class = _import_class(f"{model_module}.{model_class_name}")
+         config_class = _import_class(f"{config_module}.{config_class_name}")
+
+         # Create cross-validator
+         # Note: folds are generated with plain KFold; --stratify-by is recorded
+         # in the report but does not currently change fold generation
+         cv = KFold(n_splits=k_folds, shuffle=True, random_state=random_seed)
+
+         # Generate fold indices
+         fold_indices = list(cv.split(items_list))
+
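+         # If label stratification is needed, a sketch (assuming discrete,
+         # hashable labels) would swap in sklearn's StratifiedKFold:
+         #     from sklearn.model_selection import StratifiedKFold
+         #     cv = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=random_seed)
+         #     fold_indices = list(cv.split(items_list, labels_list))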
+         print_info(f"Generated {len(fold_indices)} folds")
+
+         # Train and evaluate on each fold
+         fold_results: list[dict[str, float | int]] = []
+
+         for fold_idx, (train_indices, test_indices) in enumerate(fold_indices, start=1):
+             print_info(f"\n[Fold {fold_idx}/{k_folds}]")
+             print_info(f" Train: {len(train_indices)} items")
+             print_info(f" Test: {len(test_indices)} items")
+
+             # Get items for train and test sets
+             train_items = [items_list[i] for i in train_indices]
+             test_items = [items_list[i] for i in test_indices]
+
+             # Get labels for this fold
+             train_labels = [labels_list[i] for i in train_indices]
+             test_labels = [labels_list[i] for i in test_indices]
+
+             # Get participant IDs for this fold (if provided)
+             train_pids: list[str] | None = None
+             test_pids: list[str] | None = None
+             if participant_ids_list is not None:
+                 train_pids = [participant_ids_list[i] for i in train_indices]
+                 test_pids = [participant_ids_list[i] for i in test_indices]
+
+             # Create and train model for this fold
+             print_info(" Training model...")
+             model_config_obj = config_class(**config_dict)
+             model_instance = model_class(config=model_config_obj)
+             model_instance.train(train_items, train_labels, participant_ids=train_pids)
+
+             # Make predictions on test set
+             predictions = model_instance.predict(test_items, participant_ids=test_pids)
+             pred_labels = [p.predicted_class for p in predictions]
+
+             # Compute metrics
+             accuracy = accuracy_score(test_labels, pred_labels)
+             precision, recall, f1, support = precision_recall_fscore_support(
+                 test_labels, pred_labels, average="macro", zero_division=0.0
+             )
+             prf: dict[str, float] = {
+                 "precision": float(precision),
+                 "recall": float(recall),
+                 "f1": float(f1),
+             }
+             # support is None when using averaging
+             if support is not None:
+                 prf["support"] = (
+                     float(support)
+                     if isinstance(support, int | float)
+                     else float(sum(support))
+                 )
+
+             fold_result: dict[str, float | int] = {
+                 "fold": fold_idx,
+                 "accuracy": float(accuracy),
+                 "precision": prf["precision"],
+                 "recall": prf["recall"],
+                 "f1": prf["f1"],
+             }
+             if "support" in prf:
+                 fold_result["support"] = prf["support"]
+             fold_results.append(fold_result)
+
+             print_success(f" Accuracy: {accuracy:.4f}, F1: {prf['f1']:.4f}")
+
+         # Compute average metrics (plain floats so the report is JSON-serializable)
+         avg_results = {
+             "accuracy": float(np.mean([r["accuracy"] for r in fold_results])),
+             "precision": float(np.mean([r["precision"] for r in fold_results])),
+             "recall": float(np.mean([r["recall"] for r in fold_results])),
+             "f1": float(np.mean([r["f1"] for r in fold_results])),
+         }
+
+         # Display summary
+         console.rule("[bold]Cross-Validation Summary[/bold]")
+         table = Table()
+         table.add_column("Metric", style="cyan")
+         table.add_column("Mean", style="green", justify="right")
+         table.add_column("Std", style="yellow", justify="right")
+
+         for metric_name in ["accuracy", "precision", "recall", "f1"]:
+             values = [r[metric_name] for r in fold_results]
+             mean_val = np.mean(values)
+             std_val = np.std(values)
+             table.add_row(metric_name.capitalize(), f"{mean_val:.4f}", f"{std_val:.4f}")
+
+         console.print(table)
+
+         # Save results
+         if output:
+             output.parent.mkdir(parents=True, exist_ok=True)
+             with open(output, "w", encoding="utf-8") as f:
+                 json.dump(
+                     {
+                         "k_folds": k_folds,
+                         "stratify_by": stratify_by,
+                         "fold_results": fold_results,
+                         "average_metrics": avg_results,
+                     },
+                     f,
+                     indent=2,
+                 )
+             print_success(f"CV results saved: {output}")
+
+     except FileNotFoundError as e:
+         print_error(f"File not found: {e}")
+         ctx.exit(1)
+     except json.JSONDecodeError as e:
+         print_error(f"Invalid JSON: {e}")
+         ctx.exit(1)
+     except ValueError as e:
+         print_error(f"Validation error: {e}")
+         ctx.exit(1)
+
+
+ @click.command()
+ @click.option(
+     "--items",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     required=True,
+     help="Path to items (JSONL)",
+ )
+ @click.option(
+     "--labels",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     required=True,
+     help="Path to labels (JSONL)",
+ )
+ @click.option(
+     "--model-config",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     required=True,
+     help="Path to model configuration (JSON)",
+ )
+ @click.option(
+     "--train-sizes",
+     default="0.1,0.2,0.5,0.8,1.0",
+     help="Comma-separated training set sizes (fractions)",
+ )
+ @click.option(
+     "--random-seed",
+     type=int,
+     help="Random seed for reproducibility",
+ )
+ @click.option(
+     "--output",
+     type=click.Path(path_type=Path),
+     help="Output path for learning curve data (JSON)",
+ )
+ @click.pass_context
+ def learning_curve(
+     ctx: click.Context,
+     items: Path,
+     labels: Path,
+     model_config: Path,
+     train_sizes: str,
+     random_seed: int | None,
+     output: Path | None,
+ ) -> None:
+     r"""Generate learning curve with varying training set sizes.
+
+     Trains models with increasing amounts of training data and reports
+     training/validation performance at each size.
+
+     Parameters
+     ----------
+     ctx : click.Context
+         Click context object.
+     items : Path
+         Path to items (JSONL).
+     labels : Path
+         Path to labels (JSONL).
+     model_config : Path
+         Path to model configuration.
+     train_sizes : str
+         Comma-separated training set sizes (fractions).
+     random_seed : int | None
+         Random seed for reproducibility.
+     output : Path | None
+         Output path for results (JSON).
+
+     Examples
+     --------
+     $ bead training learning-curve \\
+         --items data/items.jsonl \\
+         --labels data/labels.jsonl \\
+         --model-config config.json \\
+         --train-sizes 0.1,0.2,0.5,1.0 \\
+         --output learning_curve.json
+     """
+     try:
+         print_info("Generating learning curve")
+
+         # Load items
+         items_list = read_jsonlines(items, Item)
+         print_info(f"Loaded {len(items_list)} items")
+
+         # Load labels
+         with open(labels, encoding="utf-8") as f:
+             labels_list: list[str | int | float] = [
+                 json.loads(line.strip()) for line in f if line.strip()
+             ]
+
+         if len(items_list) != len(labels_list):
+             print_error(
+                 f"Mismatch: {len(items_list)} items but {len(labels_list)} labels"
+             )
+             ctx.exit(1)
+
+         # Load model config (parsed as JSON)
+         with open(model_config, encoding="utf-8") as f:
+             config_dict = json.load(f)
+
+         task_type = config_dict.get("task_type")
+         if not task_type:
+             print_error("Model config missing 'task_type' field")
+             ctx.exit(1)
+
+         # Import model and config classes
+         model_class_name = f"{task_type.title().replace('_', '')}Model"
+         config_class_name = f"{task_type.title().replace('_', '')}ModelConfig"
+         model_module = f"bead.active_learning.models.{task_type}"
+         config_module = "bead.config.active_learning"
+
+         model_class = _import_class(f"{model_module}.{model_class_name}")
+         config_class = _import_class(f"{config_module}.{config_class_name}")
+
+         # Parse train sizes
+         sizes = [float(s.strip()) for s in train_sizes.split(",")]
+         if any(s <= 0 or s > 1 for s in sizes):
+             print_error("Train sizes must be in range (0, 1]")
+             ctx.exit(1)
+
+         # Train with different data sizes
+         curve_results: list[dict[str, float]] = []
+
+         for size in track(sizes, description="Training with varying data sizes"):
+             n_samples = int(len(items_list) * size)
+             print_info(f"\nTraining with {n_samples} samples ({size:.0%})")
+
+             # Split the first n_samples items 80/20 (sequential, no shuffling)
+             split_idx = int(n_samples * 0.8)
+             train_items_subset = items_list[:split_idx]
+             test_items_subset = items_list[split_idx:n_samples]
+             train_labels_subset = labels_list[:split_idx]
+             test_labels_subset = labels_list[split_idx:n_samples]
+
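+             # Worked example (hypothetical counts): with 1000 items and
+             # size 0.5, n_samples = 500 and split_idx = 400, so items
+             # 0-399 train and items 400-499 test.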
+             # Train model
+             print_info(" Training...")
+             model_config_obj = config_class(**config_dict)
+             model_instance = model_class(config=model_config_obj)
+             # Note: participant_ids=None for fixed effects models
+             model_instance.train(
+                 train_items_subset, train_labels_subset, participant_ids=None
+             )
+
+             # Make predictions (extract class labels, as in cross-validate)
+             train_predictions = model_instance.predict(
+                 train_items_subset, participant_ids=None
+             )
+             test_predictions = model_instance.predict(
+                 test_items_subset, participant_ids=None
+             )
+             train_pred_labels = [p.predicted_class for p in train_predictions]
+             test_pred_labels = [p.predicted_class for p in test_predictions]
+
+             # Compute metrics
+             train_acc = float(accuracy_score(train_labels_subset, train_pred_labels))
+             test_acc = float(accuracy_score(test_labels_subset, test_pred_labels))
+
+             curve_results.append(
+                 {
+                     "train_size": size,
+                     "n_samples": n_samples,
+                     "train_accuracy": train_acc,
+                     "test_accuracy": test_acc,
+                 }
+             )
+
+             print_success(f" Train acc: {train_acc:.4f}, Test acc: {test_acc:.4f}")
+
+         # Display summary
+         console.rule("[bold]Learning Curve Summary[/bold]")
+         table = Table()
+         table.add_column("Train Size", style="cyan")
+         table.add_column("N Samples", style="blue", justify="right")
+         table.add_column("Train Acc", style="green", justify="right")
+         table.add_column("Test Acc", style="yellow", justify="right")
+
+         for result in curve_results:
+             table.add_row(
+                 f"{result['train_size']:.0%}",
+                 str(result["n_samples"]),
+                 f"{result['train_accuracy']:.4f}",
+                 f"{result['test_accuracy']:.4f}",
+             )
+
+         console.print(table)
+
+         # Save results
+         if output:
+             output.parent.mkdir(parents=True, exist_ok=True)
+             with open(output, "w", encoding="utf-8") as f:
+                 json.dump({"curve_data": curve_results}, f, indent=2)
+             print_success(f"Learning curve data saved: {output}")
+
+     except FileNotFoundError as e:
+         print_error(f"File not found: {e}")
+         ctx.exit(1)
+     except json.JSONDecodeError as e:
+         print_error(f"Invalid JSON: {e}")
+         ctx.exit(1)
+     except ValueError as e:
+         print_error(f"Validation error: {e}")
+         ctx.exit(1)
+
+
+ @click.command()
+ @click.option(
+     "--annotations",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     required=True,
+     help="Path to annotations (JSONL with 'rater_id' and 'label' fields)",
+ )
+ @click.option(
+     "--metric",
+     type=click.Choice(
+         [
+             "krippendorff_alpha",
+             "fleiss_kappa",
+             "cohens_kappa",
+             "percentage_agreement",
+         ]
+     ),
+     default="krippendorff_alpha",
+     help="Agreement metric to compute",
+ )
+ @click.option(
+     "--data-type",
+     type=click.Choice(["nominal", "ordinal", "interval", "ratio"]),
+     default="nominal",
+     help="Data type for Krippendorff's alpha",
+ )
+ @click.option(
+     "--output",
+     type=click.Path(path_type=Path),
+     help="Output path for agreement report (JSON)",
+ )
+ @click.pass_context
+ def compute_agreement(
+     ctx: click.Context,
+     annotations: Path,
+     metric: str,
+     data_type: str,
+     output: Path | None,
+ ) -> None:
+     r"""Compute inter-annotator agreement.
+
+     Calculates agreement metrics (Cohen's kappa, Fleiss' kappa, Krippendorff's
+     alpha, or percentage agreement) from multi-rater annotations.
+
+     Parameters
+     ----------
+     ctx : click.Context
+         Click context object.
+     annotations : Path
+         Path to annotations file (JSONL).
+     metric : str
+         Agreement metric to compute.
+     data_type : str
+         Data type for Krippendorff's alpha.
+     output : Path | None
+         Output path for report (JSON).
+
+     Examples
+     --------
+     $ bead training compute-agreement \\
+         --annotations data/annotations.jsonl \\
+         --metric krippendorff_alpha \\
+         --data-type nominal \\
+         --output agreement_report.json
+
+     $ bead training compute-agreement \\
+         --annotations data/annotations.jsonl \\
+         --metric cohens_kappa
+     """
+     try:
+         print_info(f"Computing {metric.replace('_', ' ').title()}")
+
+         # Load annotations
+         with open(annotations, encoding="utf-8") as f:
+             annotation_records = [json.loads(line) for line in f if line.strip()]
+
+         print_info(f"Loaded {len(annotation_records)} annotation records")
+
+         # Organize annotations by rater
+         rater_annotations: dict[str, list[str | int | float]] = {}
+         for record in annotation_records:
+             rater_id = str(record.get("rater_id", "unknown"))
+             label = record.get("label")
+             if rater_id not in rater_annotations:
+                 rater_annotations[rater_id] = []
+             rater_annotations[rater_id].append(label)
+
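+         # Illustrative input lines (hypothetical values):
+         #   {"rater_id": "r1", "label": "yes"}
+         #   {"rater_id": "r2", "label": "no"}
+         # Labels are grouped per rater in file order, so each rater's lines
+         # must cover the same items in the same order.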
+         n_raters = len(rater_annotations)
+         print_info(f"Found {n_raters} raters")
+
+         # Compute agreement metric
+         agreement_score: float
+         if metric == "percentage_agreement":
+             if n_raters != 2:
+                 print_error("Percentage agreement requires exactly 2 raters")
+                 ctx.exit(1)
+             rater_ids = list(rater_annotations.keys())
+             agreement_score = InterAnnotatorMetrics.percentage_agreement(
+                 rater_annotations[rater_ids[0]], rater_annotations[rater_ids[1]]
+             )
+         elif metric == "cohens_kappa":
+             if n_raters != 2:
+                 print_error("Cohen's kappa requires exactly 2 raters")
+                 ctx.exit(1)
+             rater_ids = list(rater_annotations.keys())
+             agreement_score = InterAnnotatorMetrics.cohens_kappa(
+                 rater_annotations[rater_ids[0]], rater_annotations[rater_ids[1]]
+             )
+         elif metric == "fleiss_kappa":
+             # Convert to ratings matrix format
+             # Matrix shape: (n_items, n_categories)
+             all_labels = set()
+             for labels in rater_annotations.values():
+                 all_labels.update(labels)
+             categories = sorted(all_labels)
+             # Assumes every rater labeled the same items in the same order
+             n_items = len(next(iter(rater_annotations.values())))
+
+             ratings_matrix = np.zeros((n_items, len(categories)), dtype=int)
+             for labels in rater_annotations.values():
+                 for item_idx, label in enumerate(labels):
+                     cat_idx = categories.index(label)
+                     ratings_matrix[item_idx, cat_idx] += 1
+
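+             # Worked example (hypothetical): 3 raters, 2 items, categories
+             # ["no", "yes"]; if item 0 gets yes/yes/no and item 1 gets
+             # no/no/no, the matrix is [[1, 2], [3, 0]].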
+             agreement_score = InterAnnotatorMetrics.fleiss_kappa(
+                 cast(np.ndarray[int, np.dtype[np.int_]], ratings_matrix)  # type: ignore[misc,valid-type]
+             )
+         elif metric == "krippendorff_alpha":
+             agreement_score = InterAnnotatorMetrics.krippendorff_alpha(
+                 rater_annotations, metric=data_type
+             )
+         else:
+             print_error(f"Unknown metric: {metric}")
+             ctx.exit(1)
+
+         # Display result
+         table = Table(title="Inter-Annotator Agreement")
+         table.add_column("Metric", style="cyan")
+         table.add_column("Value", style="green", justify="right")
+         table.add_column("Interpretation", style="yellow")
+
+         # Interpretation guidelines (Landis & Koch, 1977)
+         if agreement_score < 0:
+             interpretation = "Poor"
+         elif agreement_score < 0.2:
+             interpretation = "Slight"
+         elif agreement_score < 0.4:
+             interpretation = "Fair"
+         elif agreement_score < 0.6:
+             interpretation = "Moderate"
+         elif agreement_score < 0.8:
+             interpretation = "Substantial"
+         else:
+             interpretation = "Almost Perfect"
+
+         table.add_row(
+             metric.replace("_", " ").title(),
+             f"{agreement_score:.4f}",
+             interpretation,
+         )
+         table.add_row("N Raters", str(n_raters), "")
+         table.add_row("N Items", str(len(annotation_records) // n_raters), "")
+
+         console.print(table)
+
+         # Save results
+         if output:
+             output.parent.mkdir(parents=True, exist_ok=True)
+             with open(output, "w", encoding="utf-8") as f:
+                 data_type_value = data_type if metric == "krippendorff_alpha" else None
+                 json.dump(
+                     {
+                         "metric": metric,
+                         "data_type": data_type_value,
+                         "score": agreement_score,
+                         "interpretation": interpretation,
+                         "n_raters": n_raters,
+                         "n_items": len(annotation_records) // n_raters,
+                     },
+                     f,
+                     indent=2,
+                 )
+             print_success(f"Agreement report saved: {output}")
+
+     except FileNotFoundError as e:
+         print_error(f"File not found: {e}")
+         ctx.exit(1)
+     except json.JSONDecodeError as e:
+         print_error(f"Invalid JSON: {e}")
+         ctx.exit(1)
+     except ValueError as e:
+         print_error(f"Validation error: {e}")
+         ctx.exit(1)
+
+
+ # Register commands
+ training.add_command(collect_data)
+ training.add_command(show_data_stats)
+ training.add_command(evaluate)
+ training.add_command(cross_validate)
+ training.add_command(learning_curve)
+ training.add_command(compute_agreement)