mowen-cli 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
__pycache__/
|
|
2
|
+
*.py[cod]
|
|
3
|
+
*$py.class
|
|
4
|
+
*.egg-info/
|
|
5
|
+
*.egg
|
|
6
|
+
dist/
|
|
7
|
+
build/
|
|
8
|
+
.eggs/
|
|
9
|
+
*.whl
|
|
10
|
+
|
|
11
|
+
.pytest_cache/
|
|
12
|
+
.mypy_cache/
|
|
13
|
+
.ruff_cache/
|
|
14
|
+
htmlcov/
|
|
15
|
+
.coverage
|
|
16
|
+
coverage.xml
|
|
17
|
+
|
|
18
|
+
.env
|
|
19
|
+
.venv
|
|
20
|
+
venv/
|
|
21
|
+
env/
|
|
22
|
+
|
|
23
|
+
*.db
|
|
24
|
+
*.sqlite3
|
|
25
|
+
|
|
26
|
+
node_modules/
|
|
27
|
+
web/dist/
|
|
28
|
+
|
|
29
|
+
.idea/
|
|
30
|
+
.vscode/
|
|
31
|
+
*.swp
|
|
32
|
+
*.swo
|
|
33
|
+
*~
|
|
34
|
+
|
|
35
|
+
JGAAP/
|
|
36
|
+
CLAUDE.md
|
|
37
|
+
.gstack/
|
mowen_cli-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mowen-cli
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: CLI for mowen authorship attribution toolkit
|
|
5
|
+
Project-URL: Homepage, https://github.com/jnoecker/mowen
|
|
6
|
+
Project-URL: Repository, https://github.com/jnoecker/mowen
|
|
7
|
+
Author: John Noecker Jr
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
Classifier: Development Status :: 4 - Beta
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Topic :: Text Processing :: Linguistic
|
|
14
|
+
Requires-Python: >=3.11
|
|
15
|
+
Requires-Dist: mowen>=0.1.0
|
|
16
|
+
Requires-Dist: typer>=0.9
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "mowen-cli"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "CLI for mowen authorship attribution toolkit"
|
|
9
|
+
license = "MIT"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"mowen>=0.1.0",
|
|
13
|
+
"typer>=0.9",
|
|
14
|
+
]
|
|
15
|
+
authors = [
|
|
16
|
+
{ name = "John Noecker Jr" },
|
|
17
|
+
]
|
|
18
|
+
classifiers = [
|
|
19
|
+
"Development Status :: 4 - Beta",
|
|
20
|
+
"Intended Audience :: Science/Research",
|
|
21
|
+
"License :: OSI Approved :: MIT License",
|
|
22
|
+
"Programming Language :: Python :: 3",
|
|
23
|
+
"Topic :: Text Processing :: Linguistic",
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
[project.urls]
|
|
27
|
+
Homepage = "https://github.com/jnoecker/mowen"
|
|
28
|
+
Repository = "https://github.com/jnoecker/mowen"
|
|
29
|
+
|
|
30
|
+
[project.scripts]
|
|
31
|
+
mowen = "mowen_cli.main:app"
|
|
32
|
+
|
|
33
|
+
[tool.hatch.build.targets.wheel]
|
|
34
|
+
packages = ["src/mowen_cli"]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""mowen-cli — command-line interface for mowen."""
|
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
"""mowen CLI — authorship attribution from the command line."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Annotated, Optional
|
|
9
|
+
|
|
10
|
+
import typer
|
|
11
|
+
|
|
12
|
+
app = typer.Typer(
|
|
13
|
+
name="mowen",
|
|
14
|
+
help="Authorship attribution toolkit.",
|
|
15
|
+
no_args_is_help=True,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _parse_param(raw: str) -> tuple[str, str, dict[str, str]]:
|
|
20
|
+
"""Parse 'name:key=val,key=val' into (name, {params}).
|
|
21
|
+
|
|
22
|
+
Also accepts plain 'name' with no params.
|
|
23
|
+
"""
|
|
24
|
+
if ":" not in raw:
|
|
25
|
+
return raw, raw, {}
|
|
26
|
+
name, param_str = raw.split(":", 1)
|
|
27
|
+
params: dict[str, str] = {}
|
|
28
|
+
for pair in param_str.split(","):
|
|
29
|
+
if "=" not in pair:
|
|
30
|
+
continue
|
|
31
|
+
k, v = pair.split("=", 1)
|
|
32
|
+
params[k.strip()] = v.strip()
|
|
33
|
+
return name, name, params
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _spec(raw: str) -> dict:
|
|
37
|
+
"""Turn a CLI component string into a dict for PipelineConfig."""
|
|
38
|
+
name, _, params = _parse_param(raw)
|
|
39
|
+
return {"name": name, "params": params} if params else {"name": name}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _build_config(
|
|
43
|
+
event_driver: list[str],
|
|
44
|
+
distance: str,
|
|
45
|
+
analysis: str,
|
|
46
|
+
canonicizer: list[str] | None,
|
|
47
|
+
culler: list[str] | None,
|
|
48
|
+
) -> "PipelineConfig":
|
|
49
|
+
"""Build a PipelineConfig from CLI option values."""
|
|
50
|
+
from mowen.pipeline import PipelineConfig
|
|
51
|
+
|
|
52
|
+
return PipelineConfig(
|
|
53
|
+
canonicizers=[_spec(c) for c in (canonicizer or [])],
|
|
54
|
+
event_drivers=[_spec(e) for e in event_driver],
|
|
55
|
+
event_cullers=[_spec(c) for c in (culler or [])],
|
|
56
|
+
distance_function=_spec(distance),
|
|
57
|
+
analysis_method=_spec(analysis),
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _make_progress_cb(output_json: bool):
|
|
62
|
+
"""Create a terminal progress-bar callback, or None if not appropriate."""
|
|
63
|
+
if not output_json and sys.stderr.isatty():
|
|
64
|
+
def on_progress(frac: float, msg: str) -> None:
|
|
65
|
+
bar_len = 30
|
|
66
|
+
filled = int(bar_len * frac)
|
|
67
|
+
bar = "█" * filled + "░" * (bar_len - filled)
|
|
68
|
+
typer.echo(f"\r {bar} {frac:5.0%} {msg:<40}", err=True, nl=False)
|
|
69
|
+
|
|
70
|
+
return on_progress
|
|
71
|
+
return None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# ---------------------------------------------------------------------------
|
|
75
|
+
# mowen run
|
|
76
|
+
# ---------------------------------------------------------------------------
|
|
77
|
+
|
|
78
|
+
@app.command()
|
|
79
|
+
def run(
|
|
80
|
+
documents: Annotated[
|
|
81
|
+
Path,
|
|
82
|
+
typer.Option("--documents", "-d", help="CSV manifest: filepath,author (empty author = unknown)."),
|
|
83
|
+
],
|
|
84
|
+
event_driver: Annotated[
|
|
85
|
+
list[str],
|
|
86
|
+
typer.Option("--event-driver", "-e", help="Event driver (name or name:param=val,...). Repeatable."),
|
|
87
|
+
],
|
|
88
|
+
distance: Annotated[
|
|
89
|
+
str,
|
|
90
|
+
typer.Option("--distance", help="Distance function name."),
|
|
91
|
+
] = "cosine",
|
|
92
|
+
analysis: Annotated[
|
|
93
|
+
str,
|
|
94
|
+
typer.Option("--analysis", "-a", help="Analysis method (name or name:param=val,...)."),
|
|
95
|
+
] = "nearest_neighbor",
|
|
96
|
+
canonicizer: Annotated[
|
|
97
|
+
Optional[list[str]],
|
|
98
|
+
typer.Option("--canonicizer", "-c", help="Canonicizer (name or name:param=val,...). Repeatable."),
|
|
99
|
+
] = None,
|
|
100
|
+
culler: Annotated[
|
|
101
|
+
Optional[list[str]],
|
|
102
|
+
typer.Option("--culler", help="Event culler (name or name:param=val,...). Repeatable."),
|
|
103
|
+
] = None,
|
|
104
|
+
output_json: Annotated[
|
|
105
|
+
bool,
|
|
106
|
+
typer.Option("--json", help="Output results as JSON."),
|
|
107
|
+
] = False,
|
|
108
|
+
base_dir: Annotated[
|
|
109
|
+
Optional[Path],
|
|
110
|
+
typer.Option("--base-dir", help="Base directory for resolving relative paths in CSV."),
|
|
111
|
+
] = None,
|
|
112
|
+
) -> None:
|
|
113
|
+
"""Run an authorship attribution experiment."""
|
|
114
|
+
from mowen.compat.jgaap_csv import load_jgaap_csv
|
|
115
|
+
from mowen.exceptions import MowenError
|
|
116
|
+
from mowen.pipeline import Pipeline, PipelineConfig
|
|
117
|
+
|
|
118
|
+
# Load documents
|
|
119
|
+
try:
|
|
120
|
+
known, unknown = load_jgaap_csv(documents, base_dir=base_dir)
|
|
121
|
+
except Exception as e:
|
|
122
|
+
typer.echo(f"Error loading documents: {e}", err=True)
|
|
123
|
+
raise typer.Exit(1)
|
|
124
|
+
|
|
125
|
+
if not known:
|
|
126
|
+
typer.echo("Error: no known (authored) documents found in CSV.", err=True)
|
|
127
|
+
raise typer.Exit(1)
|
|
128
|
+
if not unknown:
|
|
129
|
+
typer.echo("Error: no unknown documents found in CSV. Leave the author column empty for unknowns.", err=True)
|
|
130
|
+
raise typer.Exit(1)
|
|
131
|
+
|
|
132
|
+
# Build pipeline config
|
|
133
|
+
config = _build_config(event_driver, distance, analysis, canonicizer, culler)
|
|
134
|
+
|
|
135
|
+
progress_cb = _make_progress_cb(output_json)
|
|
136
|
+
|
|
137
|
+
# Execute
|
|
138
|
+
try:
|
|
139
|
+
results = Pipeline(config, progress_callback=progress_cb).execute(known, unknown)
|
|
140
|
+
except MowenError as e:
|
|
141
|
+
typer.echo(f"\nError: {e}", err=True)
|
|
142
|
+
raise typer.Exit(1)
|
|
143
|
+
|
|
144
|
+
if progress_cb:
|
|
145
|
+
typer.echo("", err=True) # newline after progress bar
|
|
146
|
+
|
|
147
|
+
# Output
|
|
148
|
+
if output_json:
|
|
149
|
+
out = [
|
|
150
|
+
{
|
|
151
|
+
"document": r.unknown_document.title,
|
|
152
|
+
"rankings": [{"author": a.author, "score": a.score} for a in r.rankings],
|
|
153
|
+
}
|
|
154
|
+
for r in results
|
|
155
|
+
]
|
|
156
|
+
typer.echo(json.dumps(out, indent=2))
|
|
157
|
+
else:
|
|
158
|
+
for r in results:
|
|
159
|
+
typer.echo(f"\n {r.unknown_document.title}")
|
|
160
|
+
typer.echo(f" {'─' * 40}")
|
|
161
|
+
for i, a in enumerate(r.rankings):
|
|
162
|
+
marker = " → " if i == 0 else " "
|
|
163
|
+
typer.echo(f" {marker}{a.author:<25} {a.score:.4f}")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# ---------------------------------------------------------------------------
|
|
167
|
+
# mowen list-components
|
|
168
|
+
# ---------------------------------------------------------------------------
|
|
169
|
+
|
|
170
|
+
@app.command("list-components")
|
|
171
|
+
def list_components(
|
|
172
|
+
category: Annotated[
|
|
173
|
+
Optional[str],
|
|
174
|
+
typer.Argument(help="Filter by category: canonicizers, event-drivers, event-cullers, distance-functions, analysis-methods."),
|
|
175
|
+
] = None,
|
|
176
|
+
output_json: Annotated[
|
|
177
|
+
bool,
|
|
178
|
+
typer.Option("--json", help="Output as JSON."),
|
|
179
|
+
] = False,
|
|
180
|
+
) -> None:
|
|
181
|
+
"""List available pipeline components and their parameters."""
|
|
182
|
+
from mowen.analysis_methods import analysis_method_registry
|
|
183
|
+
from mowen.canonicizers import canonicizer_registry
|
|
184
|
+
from mowen.distance_functions import distance_function_registry
|
|
185
|
+
from mowen.event_cullers import event_culler_registry
|
|
186
|
+
from mowen.event_drivers import event_driver_registry
|
|
187
|
+
|
|
188
|
+
registries = {
|
|
189
|
+
"canonicizers": canonicizer_registry,
|
|
190
|
+
"event-drivers": event_driver_registry,
|
|
191
|
+
"event-cullers": event_culler_registry,
|
|
192
|
+
"distance-functions": distance_function_registry,
|
|
193
|
+
"analysis-methods": analysis_method_registry,
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
if category and category not in registries:
|
|
197
|
+
typer.echo(f"Unknown category: {category!r}. Choose from: {', '.join(registries)}", err=True)
|
|
198
|
+
raise typer.Exit(1)
|
|
199
|
+
|
|
200
|
+
selected = {category: registries[category]} if category else registries
|
|
201
|
+
|
|
202
|
+
if output_json:
|
|
203
|
+
out: dict = {
|
|
204
|
+
cat_name: registry.describe_components()
|
|
205
|
+
for cat_name, registry in selected.items()
|
|
206
|
+
}
|
|
207
|
+
typer.echo(json.dumps(out, indent=2))
|
|
208
|
+
else:
|
|
209
|
+
for cat_name, registry in selected.items():
|
|
210
|
+
typer.echo(f"\n {cat_name}")
|
|
211
|
+
typer.echo(f" {'═' * 50}")
|
|
212
|
+
for comp in registry.describe_components():
|
|
213
|
+
typer.echo(f" {comp['name']:<30} {comp['display_name']}")
|
|
214
|
+
if comp["description"]:
|
|
215
|
+
typer.echo(f" {comp['description']}")
|
|
216
|
+
for p in comp.get("params", []):
|
|
217
|
+
constraint = ""
|
|
218
|
+
if p["choices"]:
|
|
219
|
+
constraint = f" choices={p['choices']}"
|
|
220
|
+
elif p["min_value"] is not None or p["max_value"] is not None:
|
|
221
|
+
lo = p["min_value"] if p["min_value"] is not None else ""
|
|
222
|
+
hi = p["max_value"] if p["max_value"] is not None else ""
|
|
223
|
+
constraint = f" range=[{lo}, {hi}]"
|
|
224
|
+
typer.echo(
|
|
225
|
+
f" --{p['name']} ({p['type']}, default={p['default']}){constraint}"
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
# ---------------------------------------------------------------------------
|
|
230
|
+
# mowen convert-jgaap
|
|
231
|
+
# ---------------------------------------------------------------------------
|
|
232
|
+
|
|
233
|
+
@app.command("convert-jgaap")
|
|
234
|
+
def convert_jgaap(
|
|
235
|
+
csv_file: Annotated[
|
|
236
|
+
Path,
|
|
237
|
+
typer.Argument(help="Path to JGAAP experiment CSV file."),
|
|
238
|
+
],
|
|
239
|
+
base_dir: Annotated[
|
|
240
|
+
Optional[Path],
|
|
241
|
+
typer.Option("--base-dir", help="Base directory for resolving relative paths."),
|
|
242
|
+
] = None,
|
|
243
|
+
) -> None:
|
|
244
|
+
"""Convert a JGAAP CSV into a summary of loaded documents."""
|
|
245
|
+
from mowen.compat.jgaap_csv import load_jgaap_csv
|
|
246
|
+
|
|
247
|
+
try:
|
|
248
|
+
known, unknown = load_jgaap_csv(csv_file, base_dir=base_dir)
|
|
249
|
+
except Exception as e:
|
|
250
|
+
typer.echo(f"Error: {e}", err=True)
|
|
251
|
+
raise typer.Exit(1)
|
|
252
|
+
|
|
253
|
+
typer.echo(f"\n Loaded {len(known)} known + {len(unknown)} unknown documents\n")
|
|
254
|
+
|
|
255
|
+
if known:
|
|
256
|
+
typer.echo(" Known documents:")
|
|
257
|
+
authors: dict[str, int] = {}
|
|
258
|
+
for doc in known:
|
|
259
|
+
authors[doc.author or "?"] = authors.get(doc.author or "?", 0) + 1
|
|
260
|
+
for author, count in sorted(authors.items()):
|
|
261
|
+
typer.echo(f" {author}: {count} document{'s' if count != 1 else ''}")
|
|
262
|
+
|
|
263
|
+
if unknown:
|
|
264
|
+
typer.echo(f"\n Unknown documents:")
|
|
265
|
+
for doc in unknown:
|
|
266
|
+
preview = doc.text[:60].replace("\n", " ")
|
|
267
|
+
typer.echo(f" {doc.title}: {preview}...")
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
# ---------------------------------------------------------------------------
|
|
271
|
+
# mowen evaluate
|
|
272
|
+
# ---------------------------------------------------------------------------
|
|
273
|
+
|
|
274
|
+
@app.command()
|
|
275
|
+
def evaluate(
|
|
276
|
+
documents: Annotated[
|
|
277
|
+
Path,
|
|
278
|
+
typer.Option("--documents", "-d", help="CSV manifest: filepath,author. All rows must have authors."),
|
|
279
|
+
],
|
|
280
|
+
event_driver: Annotated[
|
|
281
|
+
list[str],
|
|
282
|
+
typer.Option("--event-driver", "-e", help="Event driver (name or name:param=val,...). Repeatable."),
|
|
283
|
+
],
|
|
284
|
+
distance: Annotated[
|
|
285
|
+
str,
|
|
286
|
+
typer.Option("--distance", help="Distance function name."),
|
|
287
|
+
] = "cosine",
|
|
288
|
+
analysis: Annotated[
|
|
289
|
+
str,
|
|
290
|
+
typer.Option("--analysis", "-a", help="Analysis method (name or name:param=val,...)."),
|
|
291
|
+
] = "nearest_neighbor",
|
|
292
|
+
canonicizer: Annotated[
|
|
293
|
+
Optional[list[str]],
|
|
294
|
+
typer.Option("--canonicizer", "-c", help="Canonicizer. Repeatable."),
|
|
295
|
+
] = None,
|
|
296
|
+
culler: Annotated[
|
|
297
|
+
Optional[list[str]],
|
|
298
|
+
typer.Option("--culler", help="Event culler. Repeatable."),
|
|
299
|
+
] = None,
|
|
300
|
+
mode: Annotated[
|
|
301
|
+
str,
|
|
302
|
+
typer.Option("--mode", "-m", help="Evaluation mode: loo or kfold."),
|
|
303
|
+
] = "loo",
|
|
304
|
+
folds: Annotated[
|
|
305
|
+
int,
|
|
306
|
+
typer.Option("--folds", "-k", help="Number of folds for kfold mode."),
|
|
307
|
+
] = 10,
|
|
308
|
+
seed: Annotated[
|
|
309
|
+
Optional[int],
|
|
310
|
+
typer.Option("--seed", help="Random seed for kfold shuffle."),
|
|
311
|
+
] = None,
|
|
312
|
+
output_csv: Annotated[
|
|
313
|
+
Optional[Path],
|
|
314
|
+
typer.Option("--output-csv", "-o", help="Write results to CSV file."),
|
|
315
|
+
] = None,
|
|
316
|
+
output_json: Annotated[
|
|
317
|
+
bool,
|
|
318
|
+
typer.Option("--json", help="Output results as JSON."),
|
|
319
|
+
] = False,
|
|
320
|
+
base_dir: Annotated[
|
|
321
|
+
Optional[Path],
|
|
322
|
+
typer.Option("--base-dir", help="Base directory for resolving relative paths in CSV."),
|
|
323
|
+
] = None,
|
|
324
|
+
) -> None:
|
|
325
|
+
"""Evaluate pipeline accuracy via cross-validation."""
|
|
326
|
+
from mowen.compat.jgaap_csv import load_jgaap_csv
|
|
327
|
+
from mowen.evaluation import leave_one_out as loo_eval, k_fold as kfold_eval, write_results_csv
|
|
328
|
+
from mowen.exceptions import EvaluationError, MowenError
|
|
329
|
+
|
|
330
|
+
# Load documents
|
|
331
|
+
try:
|
|
332
|
+
known, unknown = load_jgaap_csv(documents, base_dir=base_dir)
|
|
333
|
+
except Exception as e:
|
|
334
|
+
typer.echo(f"Error loading documents: {e}", err=True)
|
|
335
|
+
raise typer.Exit(1)
|
|
336
|
+
|
|
337
|
+
if unknown:
|
|
338
|
+
typer.echo(
|
|
339
|
+
f" Note: {len(unknown)} unknown document(s) ignored in evaluation mode.",
|
|
340
|
+
err=True,
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
if not known:
|
|
344
|
+
typer.echo("Error: no known (authored) documents found in CSV.", err=True)
|
|
345
|
+
raise typer.Exit(1)
|
|
346
|
+
|
|
347
|
+
config = _build_config(event_driver, distance, analysis, canonicizer, culler)
|
|
348
|
+
|
|
349
|
+
progress_cb = _make_progress_cb(output_json)
|
|
350
|
+
|
|
351
|
+
# Run evaluation
|
|
352
|
+
try:
|
|
353
|
+
if mode == "loo":
|
|
354
|
+
result = loo_eval(known, config, progress_callback=progress_cb)
|
|
355
|
+
elif mode == "kfold":
|
|
356
|
+
result = kfold_eval(
|
|
357
|
+
known, config, k=folds, random_seed=seed,
|
|
358
|
+
progress_callback=progress_cb,
|
|
359
|
+
)
|
|
360
|
+
else:
|
|
361
|
+
typer.echo(f"Unknown mode: {mode!r}. Choose 'loo' or 'kfold'.", err=True)
|
|
362
|
+
raise typer.Exit(1)
|
|
363
|
+
except (EvaluationError, MowenError) as e:
|
|
364
|
+
typer.echo(f"\nError: {e}", err=True)
|
|
365
|
+
raise typer.Exit(1)
|
|
366
|
+
|
|
367
|
+
if progress_cb:
|
|
368
|
+
typer.echo("", err=True)
|
|
369
|
+
|
|
370
|
+
# CSV export
|
|
371
|
+
if output_csv:
|
|
372
|
+
write_results_csv(result, output_csv)
|
|
373
|
+
typer.echo(f" Results written to {output_csv}", err=True)
|
|
374
|
+
|
|
375
|
+
# Output
|
|
376
|
+
if output_json:
|
|
377
|
+
out = {
|
|
378
|
+
"accuracy": result.accuracy,
|
|
379
|
+
"macro_precision": result.macro_precision,
|
|
380
|
+
"macro_recall": result.macro_recall,
|
|
381
|
+
"macro_f1": result.macro_f1,
|
|
382
|
+
"per_author": [
|
|
383
|
+
{"author": a.author, "precision": a.precision,
|
|
384
|
+
"recall": a.recall, "f1": a.f1, "support": a.support}
|
|
385
|
+
for a in result.per_author
|
|
386
|
+
],
|
|
387
|
+
"confusion_matrix": result.confusion_matrix,
|
|
388
|
+
"predictions": [
|
|
389
|
+
{"fold": fr.fold_index, "document": p.document_title,
|
|
390
|
+
"true_author": p.true_author, "predicted_author": p.predicted_author}
|
|
391
|
+
for fr in result.fold_results for p in fr.predictions
|
|
392
|
+
],
|
|
393
|
+
}
|
|
394
|
+
typer.echo(json.dumps(out, indent=2))
|
|
395
|
+
else:
|
|
396
|
+
n_docs = sum(fr.total for fr in result.fold_results)
|
|
397
|
+
n_correct = sum(fr.correct for fr in result.fold_results)
|
|
398
|
+
n_authors = len(result.per_author)
|
|
399
|
+
mode_label = "leave-one-out" if mode == "loo" else f"{folds}-fold"
|
|
400
|
+
|
|
401
|
+
typer.echo(f"\n Cross-validation: {mode_label} ({n_docs} documents, {n_authors} authors)")
|
|
402
|
+
typer.echo(f" {'═' * 56}")
|
|
403
|
+
typer.echo(f"\n Accuracy: {result.accuracy:.1%} ({n_correct}/{n_docs})")
|
|
404
|
+
|
|
405
|
+
typer.echo(f"\n Per-author metrics:")
|
|
406
|
+
typer.echo(f" {'Author':<20} {'Precision':>9} {'Recall':>9} {'F1':>9} {'Support':>8}")
|
|
407
|
+
typer.echo(f" {'─' * 20} {'─' * 9} {'─' * 9} {'─' * 9} {'─' * 8}")
|
|
408
|
+
for a in result.per_author:
|
|
409
|
+
typer.echo(
|
|
410
|
+
f" {a.author:<20} {a.precision:>9.4f} {a.recall:>9.4f} "
|
|
411
|
+
f"{a.f1:>9.4f} {a.support:>8}"
|
|
412
|
+
)
|
|
413
|
+
|
|
414
|
+
typer.echo(
|
|
415
|
+
f"\n Macro avg: P={result.macro_precision:.4f} "
|
|
416
|
+
f"R={result.macro_recall:.4f} F1={result.macro_f1:.4f}"
|
|
417
|
+
)
|
|
418
|
+
|
|
419
|
+
# Confusion matrix
|
|
420
|
+
authors = sorted(result.confusion_matrix.keys())
|
|
421
|
+
col_w = max(len(a) for a in authors) + 2
|
|
422
|
+
col_w = max(col_w, 6)
|
|
423
|
+
typer.echo(f"\n Confusion matrix:")
|
|
424
|
+
header = " " + " " * col_w + "".join(f"{a:>{col_w}}" for a in authors)
|
|
425
|
+
typer.echo(header)
|
|
426
|
+
for true_a in authors:
|
|
427
|
+
row = result.confusion_matrix[true_a]
|
|
428
|
+
cells = "".join(f"{row.get(a, 0):>{col_w}}" for a in authors)
|
|
429
|
+
typer.echo(f" {true_a:<{col_w}}{cells}")
|