deepresearch-flow 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepresearch_flow/paper/db.py +184 -0
- deepresearch_flow/paper/db_ops.py +1939 -0
- deepresearch_flow/paper/web/app.py +38 -3705
- deepresearch_flow/paper/web/constants.py +23 -0
- deepresearch_flow/paper/web/filters.py +255 -0
- deepresearch_flow/paper/web/handlers/__init__.py +14 -0
- deepresearch_flow/paper/web/handlers/api.py +217 -0
- deepresearch_flow/paper/web/handlers/pages.py +334 -0
- deepresearch_flow/paper/web/markdown.py +549 -0
- deepresearch_flow/paper/web/static/css/main.css +857 -0
- deepresearch_flow/paper/web/static/js/detail.js +406 -0
- deepresearch_flow/paper/web/static/js/index.js +266 -0
- deepresearch_flow/paper/web/static/js/outline.js +58 -0
- deepresearch_flow/paper/web/static/js/stats.js +39 -0
- deepresearch_flow/paper/web/templates/base.html +43 -0
- deepresearch_flow/paper/web/templates/detail.html +332 -0
- deepresearch_flow/paper/web/templates/index.html +114 -0
- deepresearch_flow/paper/web/templates/stats.html +29 -0
- deepresearch_flow/paper/web/templates.py +85 -0
- deepresearch_flow/paper/web/text.py +68 -0
- deepresearch_flow/recognize/cli.py +805 -26
- deepresearch_flow/recognize/katex_check.js +29 -0
- deepresearch_flow/recognize/math.py +719 -0
- deepresearch_flow/recognize/mermaid.py +690 -0
- {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/METADATA +78 -4
- {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/RECORD +30 -9
- {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/WHEEL +0 -0
- {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/entry_points.txt +0 -0
- {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/top_level.txt +0 -0
|
@@ -3,10 +3,11 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import asyncio
|
|
6
|
+
import json
|
|
6
7
|
import logging
|
|
7
8
|
import time
|
|
8
9
|
from pathlib import Path
|
|
9
|
-
from typing import Awaitable, Callable, Iterable
|
|
10
|
+
from typing import Any, Awaitable, Callable, Iterable
|
|
10
11
|
|
|
11
12
|
import click
|
|
12
13
|
import coloredlogs
|
|
@@ -15,6 +16,9 @@ from rich.console import Console
|
|
|
15
16
|
from rich.table import Table
|
|
16
17
|
from tqdm import tqdm
|
|
17
18
|
|
|
19
|
+
from deepresearch_flow.paper.config import load_config, resolve_api_keys
|
|
20
|
+
from deepresearch_flow.paper.extract import parse_model_ref
|
|
21
|
+
from deepresearch_flow.paper.template_registry import get_stage_definitions
|
|
18
22
|
from deepresearch_flow.paper.utils import discover_markdown
|
|
19
23
|
from deepresearch_flow.recognize.markdown import (
|
|
20
24
|
DEFAULT_USER_AGENT,
|
|
@@ -26,6 +30,19 @@ from deepresearch_flow.recognize.markdown import (
|
|
|
26
30
|
sanitize_filename,
|
|
27
31
|
unpack_markdown_images,
|
|
28
32
|
)
|
|
33
|
+
from deepresearch_flow.recognize.math import (
|
|
34
|
+
MathFixStats,
|
|
35
|
+
extract_math_spans,
|
|
36
|
+
fix_math_text,
|
|
37
|
+
locate_json_field_start,
|
|
38
|
+
require_pylatexenc,
|
|
39
|
+
)
|
|
40
|
+
from deepresearch_flow.recognize.mermaid import (
|
|
41
|
+
MermaidFixStats,
|
|
42
|
+
extract_mermaid_spans,
|
|
43
|
+
fix_mermaid_text,
|
|
44
|
+
require_mmdc,
|
|
45
|
+
)
|
|
29
46
|
from deepresearch_flow.recognize.organize import (
|
|
30
47
|
discover_mineru_dirs,
|
|
31
48
|
fix_markdown_text,
|
|
@@ -72,23 +89,28 @@ def _unique_output_filename(
|
|
|
72
89
|
base: str,
|
|
73
90
|
output_dirs: Iterable[Path],
|
|
74
91
|
used: set[str],
|
|
92
|
+
ext: str,
|
|
75
93
|
) -> str:
|
|
76
94
|
base = sanitize_filename(base) or "document"
|
|
77
|
-
candidate = f"{base}
|
|
95
|
+
candidate = f"{base}{ext}"
|
|
78
96
|
counter = 0
|
|
79
97
|
while candidate in used or any((directory / candidate).exists() for directory in output_dirs):
|
|
80
98
|
counter += 1
|
|
81
|
-
candidate = f"{base}_{counter}
|
|
99
|
+
candidate = f"{base}_{counter}{ext}"
|
|
82
100
|
used.add(candidate)
|
|
83
101
|
return candidate
|
|
84
102
|
|
|
85
103
|
|
|
86
|
-
def _map_output_files(
|
|
104
|
+
def _map_output_files(
|
|
105
|
+
paths: Iterable[Path],
|
|
106
|
+
output_dirs: list[Path],
|
|
107
|
+
ext: str = ".md",
|
|
108
|
+
) -> dict[Path, str]:
|
|
87
109
|
used: set[str] = set()
|
|
88
110
|
mapping: dict[Path, str] = {}
|
|
89
111
|
for path in paths:
|
|
90
112
|
base = path.stem
|
|
91
|
-
mapping[path] = _unique_output_filename(base, output_dirs, used)
|
|
113
|
+
mapping[path] = _unique_output_filename(base, output_dirs, used, ext)
|
|
92
114
|
return mapping
|
|
93
115
|
|
|
94
116
|
|
|
@@ -112,6 +134,93 @@ def _format_duration(seconds: float) -> str:
|
|
|
112
134
|
return f"{int(hours)}h {int(minutes)}m {remainder:.1f}s"
|
|
113
135
|
|
|
114
136
|
|
|
137
|
+
def _resolve_item_template(item: dict[str, Any], default_template: str | None) -> str | None:
|
|
138
|
+
raw = item.get("template_tag") or item.get("prompt_template") or default_template
|
|
139
|
+
if isinstance(raw, str) and raw:
|
|
140
|
+
return raw
|
|
141
|
+
return None
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _template_markdown_fields(template: str | None) -> list[str]:
|
|
145
|
+
if template:
|
|
146
|
+
stages = get_stage_definitions(template)
|
|
147
|
+
if stages:
|
|
148
|
+
return [field for stage in stages for field in stage.fields]
|
|
149
|
+
return ["summary", "abstract"]
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def discover_json(inputs: Iterable[str], recursive: bool) -> list[Path]:
|
|
153
|
+
files: set[Path] = set()
|
|
154
|
+
for raw in inputs:
|
|
155
|
+
path = Path(raw)
|
|
156
|
+
if path.is_file():
|
|
157
|
+
if path.suffix.lower() != ".json":
|
|
158
|
+
raise ValueError(f"Input file is not a json file: {path}")
|
|
159
|
+
files.add(path.resolve())
|
|
160
|
+
continue
|
|
161
|
+
|
|
162
|
+
if path.is_dir():
|
|
163
|
+
pattern = path.rglob("*.json") if recursive else path.glob("*.json")
|
|
164
|
+
for match in pattern:
|
|
165
|
+
if match.is_file():
|
|
166
|
+
files.add(match.resolve())
|
|
167
|
+
continue
|
|
168
|
+
|
|
169
|
+
raise FileNotFoundError(f"Input path not found: {path}")
|
|
170
|
+
|
|
171
|
+
return sorted(files)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _load_json_payload(path: Path) -> tuple[list[Any], dict[str, Any] | None, str | None]:
|
|
175
|
+
try:
|
|
176
|
+
data = json.loads(read_text(path))
|
|
177
|
+
except json.JSONDecodeError as exc:
|
|
178
|
+
raise click.ClickException(f"Invalid JSON in {path}: {exc}") from exc
|
|
179
|
+
|
|
180
|
+
if isinstance(data, list):
|
|
181
|
+
return data, None, None
|
|
182
|
+
if isinstance(data, dict):
|
|
183
|
+
papers = data.get("papers")
|
|
184
|
+
if isinstance(papers, list):
|
|
185
|
+
template_tag = data.get("template_tag")
|
|
186
|
+
return papers, data, template_tag if isinstance(template_tag, str) else None
|
|
187
|
+
raise click.ClickException(f"JSON object missing 'papers' list: {path}")
|
|
188
|
+
|
|
189
|
+
raise click.ClickException(f"Unsupported JSON structure in {path}")
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
async def _fix_json_items(
|
|
193
|
+
items: list[Any],
|
|
194
|
+
default_template: str | None,
|
|
195
|
+
fix_level: str,
|
|
196
|
+
format_enabled: bool,
|
|
197
|
+
) -> tuple[int, int, int, int]:
|
|
198
|
+
items_total = 0
|
|
199
|
+
items_updated = 0
|
|
200
|
+
fields_total = 0
|
|
201
|
+
fields_updated = 0
|
|
202
|
+
for item in items:
|
|
203
|
+
if not isinstance(item, dict):
|
|
204
|
+
continue
|
|
205
|
+
items_total += 1
|
|
206
|
+
template = _resolve_item_template(item, default_template)
|
|
207
|
+
fields = _template_markdown_fields(template)
|
|
208
|
+
item_updated = False
|
|
209
|
+
for field in fields:
|
|
210
|
+
value = item.get(field)
|
|
211
|
+
if not isinstance(value, str):
|
|
212
|
+
continue
|
|
213
|
+
fields_total += 1
|
|
214
|
+
updated = await fix_markdown_text(value, fix_level, format_enabled)
|
|
215
|
+
if updated != value:
|
|
216
|
+
item[field] = updated
|
|
217
|
+
fields_updated += 1
|
|
218
|
+
item_updated = True
|
|
219
|
+
if item_updated:
|
|
220
|
+
items_updated += 1
|
|
221
|
+
return items_total, items_updated, fields_total, fields_updated
|
|
222
|
+
|
|
223
|
+
|
|
115
224
|
async def _run_with_workers(
|
|
116
225
|
items: Iterable[Path],
|
|
117
226
|
workers: int,
|
|
@@ -226,6 +335,46 @@ async def _run_fix(
|
|
|
226
335
|
await _run_with_workers(paths, workers, handler, progress=progress)
|
|
227
336
|
|
|
228
337
|
|
|
338
|
+
async def _run_fix_json(
|
|
339
|
+
paths: list[Path],
|
|
340
|
+
output_map: dict[Path, Path],
|
|
341
|
+
fix_level: str,
|
|
342
|
+
format_enabled: bool,
|
|
343
|
+
workers: int,
|
|
344
|
+
progress: tqdm | None,
|
|
345
|
+
) -> list[tuple[int, int, int, int, int]]:
|
|
346
|
+
semaphore = asyncio.Semaphore(workers)
|
|
347
|
+
progress_lock = asyncio.Lock() if progress else None
|
|
348
|
+
results: list[tuple[int, int, int, int, int]] = []
|
|
349
|
+
|
|
350
|
+
async def handler(path: Path) -> tuple[int, int, int, int, int]:
|
|
351
|
+
items, payload, template_tag = _load_json_payload(path)
|
|
352
|
+
items_total, items_updated, fields_total, fields_updated = await _fix_json_items(
|
|
353
|
+
items, template_tag, fix_level, format_enabled
|
|
354
|
+
)
|
|
355
|
+
output_data: Any
|
|
356
|
+
if payload is None:
|
|
357
|
+
output_data = items
|
|
358
|
+
else:
|
|
359
|
+
payload["papers"] = items
|
|
360
|
+
output_data = payload
|
|
361
|
+
output_path = output_map[path]
|
|
362
|
+
serialized = json.dumps(output_data, ensure_ascii=False, indent=2)
|
|
363
|
+
await asyncio.to_thread(output_path.write_text, f"{serialized}\n", encoding="utf-8")
|
|
364
|
+
return len(items), items_total, items_updated, fields_total, fields_updated
|
|
365
|
+
|
|
366
|
+
async def runner(path: Path) -> None:
|
|
367
|
+
async with semaphore:
|
|
368
|
+
result = await handler(path)
|
|
369
|
+
results.append(result)
|
|
370
|
+
if progress and progress_lock:
|
|
371
|
+
async with progress_lock:
|
|
372
|
+
progress.update(1)
|
|
373
|
+
|
|
374
|
+
await asyncio.gather(*(runner(path) for path in paths))
|
|
375
|
+
return results
|
|
376
|
+
|
|
377
|
+
|
|
229
378
|
@click.group()
|
|
230
379
|
def recognize() -> None:
|
|
231
380
|
"""OCR recognition and Markdown post-processing commands."""
|
|
@@ -530,11 +679,12 @@ def organize(
|
|
|
530
679
|
"inputs",
|
|
531
680
|
multiple=True,
|
|
532
681
|
required=True,
|
|
533
|
-
help="Input markdown
|
|
682
|
+
help="Input markdown or JSON file/directory (repeatable)",
|
|
534
683
|
)
|
|
535
684
|
@click.option("-o", "--output", "output_dir", default=None, help="Output directory")
|
|
536
685
|
@click.option("--in-place", "in_place", is_flag=True, help="Fix markdown files in place")
|
|
537
|
-
@click.option("-r", "--recursive", is_flag=True, help="Recursively discover
|
|
686
|
+
@click.option("-r", "--recursive", is_flag=True, help="Recursively discover files")
|
|
687
|
+
@click.option("--json", "json_mode", is_flag=True, help="Fix markdown fields inside JSON outputs")
|
|
538
688
|
@click.option(
|
|
539
689
|
"--fix-level",
|
|
540
690
|
"fix_level",
|
|
@@ -552,13 +702,14 @@ def recognize_fix(
|
|
|
552
702
|
output_dir: str | None,
|
|
553
703
|
in_place: bool,
|
|
554
704
|
recursive: bool,
|
|
705
|
+
json_mode: bool,
|
|
555
706
|
fix_level: str,
|
|
556
707
|
no_format: bool,
|
|
557
708
|
workers: int,
|
|
558
709
|
dry_run: bool,
|
|
559
710
|
verbose: bool,
|
|
560
711
|
) -> None:
|
|
561
|
-
"""Fix and format OCR markdown outputs."""
|
|
712
|
+
"""Fix and format OCR markdown outputs (markdown or JSON)."""
|
|
562
713
|
configure_logging(verbose)
|
|
563
714
|
start_time = time.monotonic()
|
|
564
715
|
if workers <= 0:
|
|
@@ -573,19 +724,58 @@ def recognize_fix(
|
|
|
573
724
|
output_path = _ensure_output_dir(output_dir)
|
|
574
725
|
_warn_if_not_empty(output_path)
|
|
575
726
|
|
|
576
|
-
|
|
727
|
+
if json_mode:
|
|
728
|
+
paths = discover_json(inputs, recursive=recursive)
|
|
729
|
+
else:
|
|
730
|
+
json_inputs: list[str] = []
|
|
731
|
+
md_inputs: list[str] = []
|
|
732
|
+
for raw in inputs:
|
|
733
|
+
path = Path(raw)
|
|
734
|
+
if path.is_file():
|
|
735
|
+
suffix = path.suffix.lower()
|
|
736
|
+
if suffix == ".json":
|
|
737
|
+
json_inputs.append(raw)
|
|
738
|
+
continue
|
|
739
|
+
if suffix == ".md":
|
|
740
|
+
md_inputs.append(raw)
|
|
741
|
+
continue
|
|
742
|
+
raise click.ClickException(f"Input file must be .md or .json: {path}")
|
|
743
|
+
if path.is_dir():
|
|
744
|
+
json_inputs.append(raw)
|
|
745
|
+
md_inputs.append(raw)
|
|
746
|
+
continue
|
|
747
|
+
raise click.ClickException(f"Input path not found: {path}")
|
|
748
|
+
json_paths = discover_json(json_inputs, recursive=recursive) if json_inputs else []
|
|
749
|
+
md_paths = discover_markdown(md_inputs, None, recursive=recursive) if md_inputs else []
|
|
750
|
+
if json_paths and not md_paths:
|
|
751
|
+
json_mode = True
|
|
752
|
+
paths = json_paths
|
|
753
|
+
click.echo("Detected JSON inputs; enabling --json mode")
|
|
754
|
+
elif md_paths and not json_paths:
|
|
755
|
+
paths = md_paths
|
|
756
|
+
elif json_paths and md_paths:
|
|
757
|
+
raise click.ClickException(
|
|
758
|
+
"Found both markdown and JSON inputs; split inputs or pass --json explicitly"
|
|
759
|
+
)
|
|
760
|
+
else:
|
|
761
|
+
paths = []
|
|
577
762
|
if not paths:
|
|
578
|
-
click.echo("No
|
|
763
|
+
click.echo("No files discovered")
|
|
579
764
|
return
|
|
580
765
|
|
|
581
766
|
format_enabled = not no_format
|
|
582
767
|
if in_place:
|
|
583
768
|
output_map = {path: path for path in paths}
|
|
584
769
|
else:
|
|
585
|
-
|
|
770
|
+
ext = ".json" if json_mode else ".md"
|
|
771
|
+
output_map = {
|
|
772
|
+
path: (output_path / name)
|
|
773
|
+
for path, name in _map_output_files(paths, [output_path], ext=ext).items()
|
|
774
|
+
}
|
|
586
775
|
|
|
587
776
|
if dry_run:
|
|
588
777
|
rows = [
|
|
778
|
+
("Mode", "json" if json_mode else "markdown"),
|
|
589
779
|
("Inputs", str(len(paths))),
|
|
590
780
|
("Outputs", str(len(output_map))),
|
|
591
781
|
("Fix level", fix_level),
|
|
@@ -599,25 +789,614 @@ def recognize_fix(
|
|
|
599
789
|
|
|
600
790
|
progress = tqdm(total=len(paths), desc="fix", unit="file")
|
|
601
791
|
try:
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
792
|
+
if json_mode:
|
|
793
|
+
results = asyncio.run(
|
|
794
|
+
_run_fix_json(
|
|
795
|
+
paths,
|
|
796
|
+
output_map,
|
|
797
|
+
fix_level,
|
|
798
|
+
format_enabled,
|
|
799
|
+
workers,
|
|
800
|
+
progress,
|
|
801
|
+
)
|
|
802
|
+
)
|
|
803
|
+
else:
|
|
804
|
+
asyncio.run(
|
|
805
|
+
_run_fix(
|
|
806
|
+
paths,
|
|
807
|
+
output_map,
|
|
808
|
+
fix_level,
|
|
809
|
+
format_enabled,
|
|
810
|
+
workers,
|
|
811
|
+
progress,
|
|
812
|
+
)
|
|
813
|
+
)
|
|
814
|
+
finally:
|
|
815
|
+
progress.close()
|
|
816
|
+
if json_mode:
|
|
817
|
+
total_items = sum(result[0] for result in results)
|
|
818
|
+
items_processed = sum(result[1] for result in results)
|
|
819
|
+
items_updated = sum(result[2] for result in results)
|
|
820
|
+
fields_total = sum(result[3] for result in results)
|
|
821
|
+
fields_updated = sum(result[4] for result in results)
|
|
822
|
+
items_skipped = total_items - items_processed
|
|
823
|
+
rows = [
|
|
824
|
+
("Mode", "json"),
|
|
825
|
+
("Inputs", str(len(paths))),
|
|
826
|
+
("Outputs", str(len(output_map))),
|
|
827
|
+
("Items", str(total_items)),
|
|
828
|
+
("Items processed", str(items_processed)),
|
|
829
|
+
("Items skipped", str(items_skipped)),
|
|
830
|
+
("Items updated", str(items_updated)),
|
|
831
|
+
("Fields processed", str(fields_total)),
|
|
832
|
+
("Fields updated", str(fields_updated)),
|
|
833
|
+
("Fix level", fix_level),
|
|
834
|
+
("Format", "no" if no_format else "yes"),
|
|
835
|
+
("In place", "yes" if in_place else "no"),
|
|
836
|
+
("Output dir", _relative_path(output_path) if output_path else "-"),
|
|
837
|
+
("Duration", _format_duration(time.monotonic() - start_time)),
|
|
838
|
+
]
|
|
839
|
+
else:
|
|
840
|
+
rows = [
|
|
841
|
+
("Mode", "markdown"),
|
|
842
|
+
("Inputs", str(len(paths))),
|
|
843
|
+
("Outputs", str(len(output_map))),
|
|
844
|
+
("Fix level", fix_level),
|
|
845
|
+
("Format", "no" if no_format else "yes"),
|
|
846
|
+
("In place", "yes" if in_place else "no"),
|
|
847
|
+
("Output dir", _relative_path(output_path) if output_path else "-"),
|
|
848
|
+
("Duration", _format_duration(time.monotonic() - start_time)),
|
|
849
|
+
]
|
|
850
|
+
_print_summary("recognize fix", rows)
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
@recognize.command("fix-math")
|
|
854
|
+
@click.option("-c", "--config", "config_path", default="config.toml", help="Path to config.toml")
|
|
855
|
+
@click.option(
|
|
856
|
+
"-i",
|
|
857
|
+
"--input",
|
|
858
|
+
"inputs",
|
|
859
|
+
multiple=True,
|
|
860
|
+
required=True,
|
|
861
|
+
help="Input markdown or JSON file/directory (repeatable)",
|
|
862
|
+
)
|
|
863
|
+
@click.option("-o", "--output", "output_dir", default=None, help="Output directory")
|
|
864
|
+
@click.option("--in-place", "in_place", is_flag=True, help="Fix formulas in place")
|
|
865
|
+
@click.option("-r", "--recursive", is_flag=True, help="Recursively discover files")
|
|
866
|
+
@click.option("--json", "json_mode", is_flag=True, help="Process JSON inputs instead of markdown")
|
|
867
|
+
@click.option("-m", "--model", "model_ref", required=True, help="provider/model")
|
|
868
|
+
@click.option("--batch-size", "batch_size", default=10, show_default=True, type=int)
|
|
869
|
+
@click.option("--context-chars", "context_chars", default=80, show_default=True, type=int)
|
|
870
|
+
@click.option("--max-retries", "max_retries", default=3, show_default=True, type=int)
|
|
871
|
+
@click.option("--workers", type=int, default=4, show_default=True, help="Concurrent workers")
|
|
872
|
+
@click.option("--timeout", "timeout", default=120.0, show_default=True, type=float)
|
|
873
|
+
@click.option(
|
|
874
|
+
"--only-show-error",
|
|
875
|
+
"only_show_error",
|
|
876
|
+
is_flag=True,
|
|
877
|
+
help="Only validate formulas and report error counts",
|
|
878
|
+
)
|
|
879
|
+
@click.option("--report", "report_path", default=None, help="Error report output path")
|
|
880
|
+
@click.option("--dry-run", is_flag=True, help="Report actions without writing files")
|
|
881
|
+
@click.option("-v", "--verbose", is_flag=True, help="Enable verbose logging")
|
|
882
|
+
def recognize_fix_math(
|
|
883
|
+
config_path: str,
|
|
884
|
+
inputs: tuple[str, ...],
|
|
885
|
+
output_dir: str | None,
|
|
886
|
+
in_place: bool,
|
|
887
|
+
recursive: bool,
|
|
888
|
+
json_mode: bool,
|
|
889
|
+
model_ref: str,
|
|
890
|
+
batch_size: int,
|
|
891
|
+
context_chars: int,
|
|
892
|
+
max_retries: int,
|
|
893
|
+
workers: int,
|
|
894
|
+
timeout: float,
|
|
895
|
+
only_show_error: bool,
|
|
896
|
+
report_path: str | None,
|
|
897
|
+
dry_run: bool,
|
|
898
|
+
verbose: bool,
|
|
899
|
+
) -> None:
|
|
900
|
+
"""Validate and repair LaTeX formulas in markdown or JSON outputs."""
|
|
901
|
+
configure_logging(verbose)
|
|
902
|
+
if in_place and output_dir:
|
|
903
|
+
raise click.ClickException("--in-place cannot be used with --output")
|
|
904
|
+
if not only_show_error and not in_place and not output_dir:
|
|
905
|
+
raise click.ClickException("Either --in-place or --output is required")
|
|
906
|
+
if batch_size <= 0:
|
|
907
|
+
raise click.ClickException("--batch-size must be positive")
|
|
908
|
+
if context_chars < 0:
|
|
909
|
+
raise click.ClickException("--context-chars must be non-negative")
|
|
910
|
+
if max_retries < 0:
|
|
911
|
+
raise click.ClickException("--max-retries must be non-negative")
|
|
912
|
+
if workers <= 0:
|
|
913
|
+
raise click.ClickException("--workers must be positive")
|
|
914
|
+
try:
|
|
915
|
+
require_pylatexenc()
|
|
916
|
+
except RuntimeError as exc:
|
|
917
|
+
raise click.ClickException(str(exc)) from exc
|
|
918
|
+
|
|
919
|
+
if not json_mode:
|
|
920
|
+
file_types: set[str] = set()
|
|
921
|
+
for raw in inputs:
|
|
922
|
+
path = Path(raw)
|
|
923
|
+
if path.is_file():
|
|
924
|
+
suffix = path.suffix.lower()
|
|
925
|
+
if suffix in {".md", ".json"}:
|
|
926
|
+
file_types.add(suffix)
|
|
927
|
+
if ".md" in file_types and ".json" in file_types:
|
|
928
|
+
raise click.ClickException(
|
|
929
|
+
"Mixed markdown and JSON inputs. Use --json for JSON or split commands."
|
|
610
930
|
)
|
|
931
|
+
if ".json" in file_types:
|
|
932
|
+
json_mode = True
|
|
933
|
+
logger.info("Detected JSON inputs; enabling --json mode")
|
|
934
|
+
|
|
935
|
+
config = load_config(config_path)
|
|
936
|
+
provider, model_name = parse_model_ref(model_ref, config.providers)
|
|
937
|
+
api_keys = resolve_api_keys(provider.api_keys)
|
|
938
|
+
if provider.type in {
|
|
939
|
+
"openai_compatible",
|
|
940
|
+
"dashscope",
|
|
941
|
+
"gemini_ai_studio",
|
|
942
|
+
"azure_openai",
|
|
943
|
+
"claude",
|
|
944
|
+
} and not api_keys:
|
|
945
|
+
raise click.ClickException(f"{provider.type} providers require api_keys")
|
|
946
|
+
api_key = api_keys[0] if api_keys else None
|
|
947
|
+
|
|
948
|
+
if json_mode:
|
|
949
|
+
paths = discover_json(inputs, recursive=recursive)
|
|
950
|
+
else:
|
|
951
|
+
paths = discover_markdown(inputs, None, recursive=recursive)
|
|
952
|
+
if not paths:
|
|
953
|
+
click.echo("No files discovered")
|
|
954
|
+
return
|
|
955
|
+
|
|
956
|
+
output_path = Path(output_dir) if output_dir else None
|
|
957
|
+
if output_path and not dry_run and not only_show_error:
|
|
958
|
+
output_path = _ensure_output_dir(output_dir)
|
|
959
|
+
_warn_if_not_empty(output_path)
|
|
960
|
+
|
|
961
|
+
if in_place:
|
|
962
|
+
output_map = {path: path for path in paths}
|
|
963
|
+
elif output_path:
|
|
964
|
+
ext = ".json" if json_mode else ".md"
|
|
965
|
+
output_map = {
|
|
966
|
+
path: (output_path / name)
|
|
967
|
+
for path, name in _map_output_files(paths, [output_path], ext=ext).items()
|
|
968
|
+
}
|
|
969
|
+
else:
|
|
970
|
+
output_map = {path: path for path in paths}
|
|
971
|
+
|
|
972
|
+
report_target = None
|
|
973
|
+
if report_path:
|
|
974
|
+
report_target = Path(report_path)
|
|
975
|
+
elif not only_show_error:
|
|
976
|
+
if output_path:
|
|
977
|
+
report_target = output_path / "fix-math-errors.json"
|
|
978
|
+
elif in_place:
|
|
979
|
+
report_target = Path.cwd() / "fix-math-errors.json"
|
|
980
|
+
|
|
981
|
+
if dry_run and not only_show_error:
|
|
982
|
+
rows = [
|
|
983
|
+
("Mode", "json" if json_mode else "markdown"),
|
|
984
|
+
("Inputs", str(len(paths))),
|
|
985
|
+
("Outputs", str(len(output_map))),
|
|
986
|
+
("Batch size", str(batch_size)),
|
|
987
|
+
("Context chars", str(context_chars)),
|
|
988
|
+
("Max retries", str(max_retries)),
|
|
989
|
+
("Workers", str(workers)),
|
|
990
|
+
("Timeout", f"{timeout:.1f}s"),
|
|
991
|
+
("Only show error", "yes" if only_show_error else "no"),
|
|
992
|
+
("In place", "yes" if in_place else "no"),
|
|
993
|
+
("Output dir", _relative_path(output_path) if output_path else "-"),
|
|
994
|
+
("Report", _relative_path(report_target) if report_target else "-"),
|
|
995
|
+
]
|
|
996
|
+
_print_summary("recognize fix-math (dry-run)", rows)
|
|
997
|
+
return
|
|
998
|
+
|
|
999
|
+
progress = tqdm(total=len(paths), desc="fix-math", unit="file")
|
|
1000
|
+
formula_progress = tqdm(total=0, desc="formulas", unit="formula")
|
|
1001
|
+
error_records: list[dict[str, Any]] = []
|
|
1002
|
+
|
|
1003
|
+
async def run() -> MathFixStats:
|
|
1004
|
+
semaphore = asyncio.Semaphore(workers)
|
|
1005
|
+
progress_lock = asyncio.Lock()
|
|
1006
|
+
stats_total = MathFixStats()
|
|
1007
|
+
|
|
1008
|
+
async with httpx.AsyncClient() as client:
|
|
1009
|
+
async def handle_path(path: Path) -> MathFixStats:
|
|
1010
|
+
stats = MathFixStats()
|
|
1011
|
+
if json_mode:
|
|
1012
|
+
raw_text = read_text(path)
|
|
1013
|
+
items, payload, template_tag = _load_json_payload(path)
|
|
1014
|
+
cursor = 0
|
|
1015
|
+
for item_index, item in enumerate(items):
|
|
1016
|
+
if not isinstance(item, dict):
|
|
1017
|
+
continue
|
|
1018
|
+
template = _resolve_item_template(item, template_tag)
|
|
1019
|
+
fields = _template_markdown_fields(template)
|
|
1020
|
+
for field in fields:
|
|
1021
|
+
value = item.get(field)
|
|
1022
|
+
if not isinstance(value, str):
|
|
1023
|
+
continue
|
|
1024
|
+
spans = extract_math_spans(value, context_chars)
|
|
1025
|
+
if spans:
|
|
1026
|
+
formula_progress.total += len(spans)
|
|
1027
|
+
formula_progress.refresh()
|
|
1028
|
+
line_start, cursor = locate_json_field_start(raw_text, value, cursor)
|
|
1029
|
+
field_path = f"papers[{item_index}].{field}"
|
|
1030
|
+
updated, errors = await fix_math_text(
|
|
1031
|
+
value,
|
|
1032
|
+
str(path),
|
|
1033
|
+
line_start,
|
|
1034
|
+
field_path,
|
|
1035
|
+
item_index,
|
|
1036
|
+
provider,
|
|
1037
|
+
model_name,
|
|
1038
|
+
api_key,
|
|
1039
|
+
timeout,
|
|
1040
|
+
max_retries,
|
|
1041
|
+
batch_size,
|
|
1042
|
+
context_chars,
|
|
1043
|
+
client,
|
|
1044
|
+
stats,
|
|
1045
|
+
repair_enabled=not only_show_error,
|
|
1046
|
+
spans=spans,
|
|
1047
|
+
progress_cb=lambda: formula_progress.update(1),
|
|
1048
|
+
)
|
|
1049
|
+
if not only_show_error and updated != value:
|
|
1050
|
+
item[field] = updated
|
|
1051
|
+
error_records.extend(errors)
|
|
1052
|
+
if not only_show_error:
|
|
1053
|
+
output_data: Any = items if payload is None else {**payload, "papers": items}
|
|
1054
|
+
output_path = output_map[path]
|
|
1055
|
+
serialized = json.dumps(output_data, ensure_ascii=False, indent=2)
|
|
1056
|
+
await asyncio.to_thread(output_path.write_text, f"{serialized}\n", encoding="utf-8")
|
|
1057
|
+
else:
|
|
1058
|
+
content = await asyncio.to_thread(read_text, path)
|
|
1059
|
+
spans = extract_math_spans(content, context_chars)
|
|
1060
|
+
if spans:
|
|
1061
|
+
formula_progress.total += len(spans)
|
|
1062
|
+
formula_progress.refresh()
|
|
1063
|
+
updated, errors = await fix_math_text(
|
|
1064
|
+
content,
|
|
1065
|
+
str(path),
|
|
1066
|
+
1,
|
|
1067
|
+
None,
|
|
1068
|
+
None,
|
|
1069
|
+
provider,
|
|
1070
|
+
model_name,
|
|
1071
|
+
api_key,
|
|
1072
|
+
timeout,
|
|
1073
|
+
max_retries,
|
|
1074
|
+
batch_size,
|
|
1075
|
+
context_chars,
|
|
1076
|
+
client,
|
|
1077
|
+
stats,
|
|
1078
|
+
repair_enabled=not only_show_error,
|
|
1079
|
+
spans=spans,
|
|
1080
|
+
progress_cb=lambda: formula_progress.update(1),
|
|
1081
|
+
)
|
|
1082
|
+
if not only_show_error:
|
|
1083
|
+
output_path = output_map[path]
|
|
1084
|
+
await asyncio.to_thread(output_path.write_text, updated, encoding="utf-8")
|
|
1085
|
+
error_records.extend(errors)
|
|
1086
|
+
return stats
|
|
1087
|
+
|
|
1088
|
+
async def runner(path: Path) -> None:
|
|
1089
|
+
async with semaphore:
|
|
1090
|
+
stats = await handle_path(path)
|
|
1091
|
+
stats_total.formulas_total += stats.formulas_total
|
|
1092
|
+
stats_total.formulas_invalid += stats.formulas_invalid
|
|
1093
|
+
stats_total.formulas_cleaned += stats.formulas_cleaned
|
|
1094
|
+
stats_total.formulas_repaired += stats.formulas_repaired
|
|
1095
|
+
stats_total.formulas_failed += stats.formulas_failed
|
|
1096
|
+
async with progress_lock:
|
|
1097
|
+
progress.update(1)
|
|
1098
|
+
|
|
1099
|
+
await asyncio.gather(*(runner(path) for path in paths))
|
|
1100
|
+
return stats_total
|
|
1101
|
+
|
|
1102
|
+
try:
|
|
1103
|
+
stats = asyncio.run(run())
|
|
1104
|
+
finally:
|
|
1105
|
+
progress.close()
|
|
1106
|
+
formula_progress.close()
|
|
1107
|
+
|
|
1108
|
+
if report_target and error_records:
|
|
1109
|
+
report_target.parent.mkdir(parents=True, exist_ok=True)
|
|
1110
|
+
report_target.write_text(
|
|
1111
|
+
json.dumps(error_records, ensure_ascii=False, indent=2) + "\n",
|
|
1112
|
+
encoding="utf-8",
|
|
611
1113
|
)
|
|
1114
|
+
|
|
1115
|
+
rows = [
|
|
1116
|
+
("Mode", "json" if json_mode else "markdown"),
|
|
1117
|
+
("Inputs", str(len(paths))),
|
|
1118
|
+
("Outputs", str(len(output_map) if not only_show_error else 0)),
|
|
1119
|
+
("Formulas", str(stats.formulas_total)),
|
|
1120
|
+
("Invalid", str(stats.formulas_invalid)),
|
|
1121
|
+
("Cleaned", str(stats.formulas_cleaned)),
|
|
1122
|
+
("Repaired", str(stats.formulas_repaired)),
|
|
1123
|
+
("Failed", str(stats.formulas_failed)),
|
|
1124
|
+
("Only show error", "yes" if only_show_error else "no"),
|
|
1125
|
+
("Report", _relative_path(report_target) if report_target else "-"),
|
|
1126
|
+
]
|
|
1127
|
+
_print_summary("recognize fix-math", rows)
|
|
1128
|
+
|
|
1129
|
+
|
|
1130
|
+
@recognize.command("fix-mermaid")
@click.option("-c", "--config", "config_path", default="config.toml", help="Path to config.toml")
@click.option(
    "-i",
    "--input",
    "inputs",
    multiple=True,
    required=True,
    help="Input markdown or JSON file/directory (repeatable)",
)
@click.option("-o", "--output", "output_dir", default=None, help="Output directory")
@click.option("--in-place", "in_place", is_flag=True, help="Fix Mermaid blocks in place")
@click.option("-r", "--recursive", is_flag=True, help="Recursively discover files")
@click.option("--json", "json_mode", is_flag=True, help="Process JSON inputs instead of markdown")
@click.option("-m", "--model", "model_ref", required=True, help="provider/model")
@click.option("--batch-size", "batch_size", default=10, show_default=True, type=int)
@click.option("--context-chars", "context_chars", default=80, show_default=True, type=int)
@click.option("--max-retries", "max_retries", default=3, show_default=True, type=int)
@click.option("--workers", type=int, default=4, show_default=True, help="Concurrent workers")
@click.option("--timeout", "timeout", default=120.0, show_default=True, type=float)
@click.option(
    "--only-show-error",
    "only_show_error",
    is_flag=True,
    help="Only validate Mermaid blocks and report error counts",
)
@click.option("--report", "report_path", default=None, help="Error report output path")
@click.option("--dry-run", is_flag=True, help="Report actions without writing files")
@click.option("-v", "--verbose", is_flag=True, help="Enable verbose logging")
def recognize_fix_mermaid(
    config_path: str,
    inputs: tuple[str, ...],
    output_dir: str | None,
    in_place: bool,
    recursive: bool,
    json_mode: bool,
    model_ref: str,
    batch_size: int,
    context_chars: int,
    max_retries: int,
    workers: int,
    timeout: float,
    only_show_error: bool,
    report_path: str | None,
    dry_run: bool,
    verbose: bool,
) -> None:
    """Validate and repair Mermaid diagrams in markdown or JSON outputs.

    Pipeline: validate flag combinations, auto-detect JSON inputs,
    resolve the LLM provider and API key, discover input files, then run
    ``fix_mermaid_text`` over each file concurrently (bounded by
    ``--workers``). In ``--only-show-error`` mode diagrams are only
    validated and counted; otherwise repaired text is written either in
    place or under ``--output``, and failing diagrams are collected into
    a JSON error report. A summary table is printed at the end.

    Raises:
        click.ClickException: on invalid flag combinations, missing
            ``mmdc``, missing API keys, or mixed markdown/JSON inputs.
    """
    configure_logging(verbose)
    # --- CLI flag validation: exactly one write destination is required
    # unless the run is read-only (--only-show-error). ---
    if in_place and output_dir:
        raise click.ClickException("--in-place cannot be used with --output")
    if not only_show_error and not in_place and not output_dir:
        raise click.ClickException("Either --in-place or --output is required")
    if batch_size <= 0:
        raise click.ClickException("--batch-size must be positive")
    if context_chars < 0:
        raise click.ClickException("--context-chars must be non-negative")
    if max_retries < 0:
        raise click.ClickException("--max-retries must be non-negative")
    if workers <= 0:
        raise click.ClickException("--workers must be positive")
    # Mermaid validation shells out to the mermaid-cli (mmdc); fail fast
    # with a user-facing message if it is not installed.
    try:
        require_mmdc()
    except RuntimeError as exc:
        raise click.ClickException(str(exc)) from exc

    # Auto-detect JSON mode from file suffixes when --json was not given.
    # A single invocation may not mix .md and .json files.
    if not json_mode:
        file_types: set[str] = set()
        for raw in inputs:
            path = Path(raw)
            if path.is_file():
                suffix = path.suffix.lower()
                if suffix in {".md", ".json"}:
                    file_types.add(suffix)
        if ".md" in file_types and ".json" in file_types:
            raise click.ClickException(
                "Mixed markdown and JSON inputs. Use --json for JSON or split commands."
            )
        if ".json" in file_types:
            json_mode = True
            logger.info("Detected JSON inputs; enabling --json mode")

    # Resolve provider/model from config; these provider types cannot
    # operate without at least one API key.
    config = load_config(config_path)
    provider, model_name = parse_model_ref(model_ref, config.providers)
    api_keys = resolve_api_keys(provider.api_keys)
    if provider.type in {
        "openai_compatible",
        "dashscope",
        "gemini_ai_studio",
        "azure_openai",
        "claude",
    } and not api_keys:
        raise click.ClickException(f"{provider.type} providers require api_keys")
    # Only the first key is used for this run.
    api_key = api_keys[0] if api_keys else None

    if json_mode:
        paths = discover_json(inputs, recursive=recursive)
    else:
        paths = discover_markdown(inputs, None, recursive=recursive)
    if not paths:
        click.echo("No files discovered")
        return

    # Create the output directory only when we will actually write files
    # (neither dry-run nor validate-only).
    output_path = Path(output_dir) if output_dir else None
    if output_path and not dry_run and not only_show_error:
        output_path = _ensure_output_dir(output_dir)
        _warn_if_not_empty(output_path)

    # Map every input file to its write destination. In-place (and the
    # validate-only fallback, which never writes) map each path to itself.
    if in_place:
        output_map = {path: path for path in paths}
    elif output_path:
        ext = ".json" if json_mode else ".md"
        output_map = {
            path: (output_path / name)
            for path, name in _map_output_files(paths, [output_path], ext=ext).items()
        }
    else:
        output_map = {path: path for path in paths}

    # Error-report destination: explicit --report wins; otherwise default
    # next to the outputs (or CWD for in-place). Validate-only runs get a
    # report only when --report is passed explicitly.
    report_target = None
    if report_path:
        report_target = Path(report_path)
    elif not only_show_error:
        if output_path:
            report_target = output_path / "fix-mermaid-errors.json"
        elif in_place:
            report_target = Path.cwd() / "fix-mermaid-errors.json"

    # Dry run: print the planned work and exit without touching files.
    if dry_run and not only_show_error:
        rows = [
            ("Mode", "json" if json_mode else "markdown"),
            ("Inputs", str(len(paths))),
            ("Outputs", str(len(output_map))),
            ("Batch size", str(batch_size)),
            ("Context chars", str(context_chars)),
            ("Max retries", str(max_retries)),
            ("Workers", str(workers)),
            ("Timeout", f"{timeout:.1f}s"),
            ("Only show error", "yes" if only_show_error else "no"),
            ("In place", "yes" if in_place else "no"),
            ("Output dir", _relative_path(output_path) if output_path else "-"),
            ("Report", _relative_path(report_target) if report_target else "-"),
        ]
        _print_summary("recognize fix-mermaid (dry-run)", rows)
        return

    # Two progress bars: one per file, one per diagram. The diagram bar's
    # total grows as spans are discovered inside each file.
    progress = tqdm(total=len(paths), desc="fix-mermaid", unit="file")
    diagram_progress = tqdm(total=0, desc="diagrams", unit="diagram")
    error_records: list[dict[str, Any]] = []

    async def run() -> MermaidFixStats:
        # Concurrency is bounded by a semaphore; all workers share one
        # HTTP client and aggregate into stats_total.
        semaphore = asyncio.Semaphore(workers)
        progress_lock = asyncio.Lock()
        stats_total = MermaidFixStats()

        async with httpx.AsyncClient() as client:

            async def handle_path(path: Path) -> MermaidFixStats:
                # Process one file; returns its per-file diagram stats.
                stats = MermaidFixStats()
                if json_mode:
                    # NOTE(review): read_text/_load_json_payload run
                    # synchronously on the event loop here, unlike the
                    # markdown branch which offloads reads via
                    # asyncio.to_thread — confirm this is intentional.
                    raw_text = read_text(path)
                    items, payload, template_tag = _load_json_payload(path)
                    # cursor tracks the search position in raw_text so that
                    # locate_json_field_start reports 1-based line numbers
                    # for error records.
                    cursor = 0
                    for item_index, item in enumerate(items):
                        if not isinstance(item, dict):
                            continue
                        template = _resolve_item_template(item, template_tag)
                        fields = _template_markdown_fields(template)
                        for field in fields:
                            value = item.get(field)
                            if not isinstance(value, str):
                                continue
                            spans = extract_mermaid_spans(value, context_chars)
                            if spans:
                                diagram_progress.total += len(spans)
                                diagram_progress.refresh()
                            line_start, cursor = locate_json_field_start(raw_text, value, cursor)
                            field_path = f"papers[{item_index}].{field}"
                            updated, errors = await fix_mermaid_text(
                                value,
                                str(path),
                                line_start,
                                field_path,
                                item_index,
                                provider,
                                model_name,
                                api_key,
                                timeout,
                                max_retries,
                                batch_size,
                                context_chars,
                                client,
                                stats,
                                repair_enabled=not only_show_error,
                                spans=spans,
                                progress_cb=lambda: diagram_progress.update(1),
                            )
                            # Write back into the parsed item only when a
                            # repair actually changed the field.
                            if not only_show_error and updated != value:
                                item[field] = updated
                            error_records.extend(errors)
                    if not only_show_error:
                        # Preserve the original wrapper object (if any)
                        # around the "papers" list when serializing.
                        output_data: Any = items if payload is None else {**payload, "papers": items}
                        output_path = output_map[path]
                        serialized = json.dumps(output_data, ensure_ascii=False, indent=2)
                        await asyncio.to_thread(output_path.write_text, f"{serialized}\n", encoding="utf-8")
                else:
                    # Markdown branch: the whole file is one text blob.
                    content = await asyncio.to_thread(read_text, path)
                    spans = extract_mermaid_spans(content, context_chars)
                    if spans:
                        diagram_progress.total += len(spans)
                        diagram_progress.refresh()
                    # line_start=1 and no field path / item index for
                    # plain markdown inputs.
                    updated, errors = await fix_mermaid_text(
                        content,
                        str(path),
                        1,
                        None,
                        None,
                        provider,
                        model_name,
                        api_key,
                        timeout,
                        max_retries,
                        batch_size,
                        context_chars,
                        client,
                        stats,
                        repair_enabled=not only_show_error,
                        spans=spans,
                        progress_cb=lambda: diagram_progress.update(1),
                    )
                    if not only_show_error:
                        output_path = output_map[path]
                        await asyncio.to_thread(output_path.write_text, updated, encoding="utf-8")
                    error_records.extend(errors)
                return stats

            async def runner(path: Path) -> None:
                # Bound concurrency, then fold per-file stats into the
                # shared totals (single-threaded event loop, so the bare
                # increments are not a data race).
                async with semaphore:
                    stats = await handle_path(path)
                    stats_total.diagrams_total += stats.diagrams_total
                    stats_total.diagrams_invalid += stats.diagrams_invalid
                    stats_total.diagrams_repaired += stats.diagrams_repaired
                    stats_total.diagrams_failed += stats.diagrams_failed
                    async with progress_lock:
                        progress.update(1)

            await asyncio.gather(*(runner(path) for path in paths))
        return stats_total

    try:
        stats = asyncio.run(run())
    finally:
        # Always close the progress bars, even if the run raised.
        progress.close()
        diagram_progress.close()

    # Persist the error report only when there is something to report.
    if report_target and error_records:
        report_target.parent.mkdir(parents=True, exist_ok=True)
        report_target.write_text(
            json.dumps(error_records, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )

    rows = [
        ("Mode", "json" if json_mode else "markdown"),
        ("Inputs", str(len(paths))),
        # Validate-only runs write nothing, so report zero outputs.
        ("Outputs", str(len(output_map) if not only_show_error else 0)),
        ("Diagrams", str(stats.diagrams_total)),
        ("Invalid", str(stats.diagrams_invalid)),
        ("Repaired", str(stats.diagrams_repaired)),
        ("Failed", str(stats.diagrams_failed)),
        ("Only show error", "yes" if only_show_error else "no"),
        ("Report", _relative_path(report_target) if report_target else "-"),
    ]
    _print_summary("recognize fix-mermaid", rows)