deepresearch-flow 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. deepresearch_flow/paper/db.py +184 -0
  2. deepresearch_flow/paper/db_ops.py +1939 -0
  3. deepresearch_flow/paper/web/app.py +38 -3705
  4. deepresearch_flow/paper/web/constants.py +23 -0
  5. deepresearch_flow/paper/web/filters.py +255 -0
  6. deepresearch_flow/paper/web/handlers/__init__.py +14 -0
  7. deepresearch_flow/paper/web/handlers/api.py +217 -0
  8. deepresearch_flow/paper/web/handlers/pages.py +334 -0
  9. deepresearch_flow/paper/web/markdown.py +549 -0
  10. deepresearch_flow/paper/web/static/css/main.css +857 -0
  11. deepresearch_flow/paper/web/static/js/detail.js +406 -0
  12. deepresearch_flow/paper/web/static/js/index.js +266 -0
  13. deepresearch_flow/paper/web/static/js/outline.js +58 -0
  14. deepresearch_flow/paper/web/static/js/stats.js +39 -0
  15. deepresearch_flow/paper/web/templates/base.html +43 -0
  16. deepresearch_flow/paper/web/templates/detail.html +332 -0
  17. deepresearch_flow/paper/web/templates/index.html +114 -0
  18. deepresearch_flow/paper/web/templates/stats.html +29 -0
  19. deepresearch_flow/paper/web/templates.py +85 -0
  20. deepresearch_flow/paper/web/text.py +68 -0
  21. deepresearch_flow/recognize/cli.py +805 -26
  22. deepresearch_flow/recognize/katex_check.js +29 -0
  23. deepresearch_flow/recognize/math.py +719 -0
  24. deepresearch_flow/recognize/mermaid.py +690 -0
  25. {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/METADATA +78 -4
  26. {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/RECORD +30 -9
  27. {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/WHEEL +0 -0
  28. {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/entry_points.txt +0 -0
  29. {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
  30. {deepresearch_flow-0.3.0.dist-info → deepresearch_flow-0.4.1.dist-info}/top_level.txt +0 -0
@@ -3,10 +3,11 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import asyncio
6
+ import json
6
7
  import logging
7
8
  import time
8
9
  from pathlib import Path
9
- from typing import Awaitable, Callable, Iterable
10
+ from typing import Any, Awaitable, Callable, Iterable
10
11
 
11
12
  import click
12
13
  import coloredlogs
@@ -15,6 +16,9 @@ from rich.console import Console
15
16
  from rich.table import Table
16
17
  from tqdm import tqdm
17
18
 
19
+ from deepresearch_flow.paper.config import load_config, resolve_api_keys
20
+ from deepresearch_flow.paper.extract import parse_model_ref
21
+ from deepresearch_flow.paper.template_registry import get_stage_definitions
18
22
  from deepresearch_flow.paper.utils import discover_markdown
19
23
  from deepresearch_flow.recognize.markdown import (
20
24
  DEFAULT_USER_AGENT,
@@ -26,6 +30,19 @@ from deepresearch_flow.recognize.markdown import (
26
30
  sanitize_filename,
27
31
  unpack_markdown_images,
28
32
  )
33
+ from deepresearch_flow.recognize.math import (
34
+ MathFixStats,
35
+ extract_math_spans,
36
+ fix_math_text,
37
+ locate_json_field_start,
38
+ require_pylatexenc,
39
+ )
40
+ from deepresearch_flow.recognize.mermaid import (
41
+ MermaidFixStats,
42
+ extract_mermaid_spans,
43
+ fix_mermaid_text,
44
+ require_mmdc,
45
+ )
29
46
  from deepresearch_flow.recognize.organize import (
30
47
  discover_mineru_dirs,
31
48
  fix_markdown_text,
def _unique_output_filename(
    base: str,
    output_dirs: Iterable[Path],
    used: set[str],
    ext: str,
) -> str:
    """Return a filename derived from *base* that collides with nothing.

    A name counts as taken when it was already handed out (*used*) or when a
    file with that name exists in any of *output_dirs*.  The chosen name is
    recorded in *used* before being returned.
    """
    stem = sanitize_filename(base) or "document"

    def taken(name: str) -> bool:
        return name in used or any((folder / name).exists() for folder in output_dirs)

    candidate = f"{stem}{ext}"
    serial = 0
    while taken(candidate):
        serial += 1
        candidate = f"{stem}_{serial}{ext}"
    used.add(candidate)
    return candidate
84
102
 
85
103
 
86
def _map_output_files(
    paths: Iterable[Path],
    output_dirs: list[Path],
    ext: str = ".md",
) -> dict[Path, str]:
    """Assign a collision-free output filename (carrying *ext*) to each path.

    Names are derived from each path's stem; uniqueness is enforced across
    this batch and against files already present in *output_dirs*.
    """
    handed_out: set[str] = set()
    return {
        path: _unique_output_filename(path.stem, output_dirs, handed_out, ext)
        for path in paths
    }
93
115
 
94
116
 
@@ -112,6 +134,93 @@ def _format_duration(seconds: float) -> str:
112
134
  return f"{int(hours)}h {int(minutes)}m {remainder:.1f}s"
113
135
 
114
136
 
137
+ def _resolve_item_template(item: dict[str, Any], default_template: str | None) -> str | None:
138
+ raw = item.get("template_tag") or item.get("prompt_template") or default_template
139
+ if isinstance(raw, str) and raw:
140
+ return raw
141
+ return None
142
+
143
+
144
+ def _template_markdown_fields(template: str | None) -> list[str]:
145
+ if template:
146
+ stages = get_stage_definitions(template)
147
+ if stages:
148
+ return [field for stage in stages for field in stage.fields]
149
+ return ["summary", "abstract"]
150
+
151
+
152
def discover_json(inputs: Iterable[str], recursive: bool) -> list[Path]:
    """Collect ``.json`` files from the given file/directory inputs.

    Explicit file inputs must carry a ``.json`` suffix; directory inputs are
    scanned (recursively when *recursive* is set).  Results are resolved,
    de-duplicated, and returned sorted.

    Raises ValueError for a non-JSON input file and FileNotFoundError for a
    path that is neither a file nor a directory.
    """
    found: set[Path] = set()
    for raw in inputs:
        candidate = Path(raw)
        if candidate.is_file():
            if candidate.suffix.lower() != ".json":
                raise ValueError(f"Input file is not a json file: {candidate}")
            found.add(candidate.resolve())
        elif candidate.is_dir():
            walker = candidate.rglob("*.json") if recursive else candidate.glob("*.json")
            found.update(entry.resolve() for entry in walker if entry.is_file())
        else:
            raise FileNotFoundError(f"Input path not found: {candidate}")
    return sorted(found)
172
+
173
+
174
def _load_json_payload(path: Path) -> tuple[list[Any], dict[str, Any] | None, str | None]:
    """Parse *path* and return ``(items, wrapper, template_tag)``.

    Two shapes are accepted: a bare JSON list (wrapper and tag are None), or
    a JSON object holding a ``papers`` list, in which case the original
    object and its string ``template_tag`` (if any) are returned as well.

    Raises click.ClickException for malformed JSON or unsupported shapes.
    """
    try:
        parsed = json.loads(read_text(path))
    except json.JSONDecodeError as exc:
        raise click.ClickException(f"Invalid JSON in {path}: {exc}") from exc

    if isinstance(parsed, list):
        return parsed, None, None

    if not isinstance(parsed, dict):
        raise click.ClickException(f"Unsupported JSON structure in {path}")

    papers = parsed.get("papers")
    if not isinstance(papers, list):
        raise click.ClickException(f"JSON object missing 'papers' list: {path}")

    tag = parsed.get("template_tag")
    return papers, parsed, tag if isinstance(tag, str) else None
190
+
191
+
192
async def _fix_json_items(
    items: list[Any],
    default_template: str | None,
    fix_level: str,
    format_enabled: bool,
) -> tuple[int, int, int, int]:
    """Run the markdown fixer over every string field of every dict item.

    Field names come from each item's resolved template.  Changed values are
    written back into the item in place.  Non-dict entries and non-string
    field values are skipped.

    Returns (items seen, items changed, fields processed, fields changed).
    """
    seen = changed = field_count = field_changed = 0
    for entry in items:
        if not isinstance(entry, dict):
            continue
        seen += 1
        template = _resolve_item_template(entry, default_template)
        touched = False
        for name in _template_markdown_fields(template):
            original = entry.get(name)
            if not isinstance(original, str):
                continue
            field_count += 1
            fixed = await fix_markdown_text(original, fix_level, format_enabled)
            if fixed != original:
                entry[name] = fixed
                field_changed += 1
                touched = True
        if touched:
            changed += 1
    return seen, changed, field_count, field_changed
222
+
223
+
115
224
  async def _run_with_workers(
116
225
  items: Iterable[Path],
117
226
  workers: int,
@@ -226,6 +335,46 @@ async def _run_fix(
226
335
  await _run_with_workers(paths, workers, handler, progress=progress)
227
336
 
228
337
 
338
async def _run_fix_json(
    paths: list[Path],
    output_map: dict[Path, Path],
    fix_level: str,
    format_enabled: bool,
    workers: int,
    progress: tqdm | None,
) -> list[tuple[int, int, int, int, int]]:
    """Fix markdown fields of each JSON file concurrently and write results.

    At most *workers* files are processed at once.  Each file's tuple is
    (raw item count, dict items seen, items updated, fields seen, fields
    updated); tuples are collected in completion order.
    """
    gate = asyncio.Semaphore(workers)
    tick_lock = asyncio.Lock() if progress else None
    collected: list[tuple[int, int, int, int, int]] = []

    async def process(path: Path) -> tuple[int, int, int, int, int]:
        # Load, fix in place, then serialize either the bare list or the
        # original wrapper object with its "papers" list replaced.
        items, payload, template_tag = _load_json_payload(path)
        seen, updated, fields_seen, fields_updated = await _fix_json_items(
            items, template_tag, fix_level, format_enabled
        )
        if payload is None:
            document: Any = items
        else:
            payload["papers"] = items
            document = payload
        target = output_map[path]
        text = json.dumps(document, ensure_ascii=False, indent=2)
        await asyncio.to_thread(target.write_text, f"{text}\n", encoding="utf-8")
        return len(items), seen, updated, fields_seen, fields_updated

    async def worker(path: Path) -> None:
        async with gate:
            collected.append(await process(path))
            if progress and tick_lock:
                async with tick_lock:
                    progress.update(1)

    await asyncio.gather(*(worker(path) for path in paths))
    return collected
376
+
377
+
229
378
  @click.group()
230
379
  def recognize() -> None:
231
380
  """OCR recognition and Markdown post-processing commands."""
@@ -530,11 +679,12 @@ def organize(
530
679
  "inputs",
531
680
  multiple=True,
532
681
  required=True,
533
- help="Input markdown file or directory (repeatable)",
682
+ help="Input markdown or JSON file/directory (repeatable)",
534
683
  )
535
684
  @click.option("-o", "--output", "output_dir", default=None, help="Output directory")
536
685
  @click.option("--in-place", "in_place", is_flag=True, help="Fix markdown files in place")
537
- @click.option("-r", "--recursive", is_flag=True, help="Recursively discover markdown files")
686
+ @click.option("-r", "--recursive", is_flag=True, help="Recursively discover files")
687
+ @click.option("--json", "json_mode", is_flag=True, help="Fix markdown fields inside JSON outputs")
538
688
  @click.option(
539
689
  "--fix-level",
540
690
  "fix_level",
@@ -552,13 +702,14 @@ def recognize_fix(
552
702
  output_dir: str | None,
553
703
  in_place: bool,
554
704
  recursive: bool,
705
+ json_mode: bool,
555
706
  fix_level: str,
556
707
  no_format: bool,
557
708
  workers: int,
558
709
  dry_run: bool,
559
710
  verbose: bool,
560
711
  ) -> None:
561
- """Fix and format OCR markdown outputs."""
712
+ """Fix and format OCR markdown outputs (markdown or JSON)."""
562
713
  configure_logging(verbose)
563
714
  start_time = time.monotonic()
564
715
  if workers <= 0:
@@ -573,19 +724,58 @@ def recognize_fix(
573
724
  output_path = _ensure_output_dir(output_dir)
574
725
  _warn_if_not_empty(output_path)
575
726
 
576
- paths = discover_markdown(inputs, None, recursive=recursive)
727
+ if json_mode:
728
+ paths = discover_json(inputs, recursive=recursive)
729
+ else:
730
+ json_inputs: list[str] = []
731
+ md_inputs: list[str] = []
732
+ for raw in inputs:
733
+ path = Path(raw)
734
+ if path.is_file():
735
+ suffix = path.suffix.lower()
736
+ if suffix == ".json":
737
+ json_inputs.append(raw)
738
+ continue
739
+ if suffix == ".md":
740
+ md_inputs.append(raw)
741
+ continue
742
+ raise click.ClickException(f"Input file must be .md or .json: {path}")
743
+ if path.is_dir():
744
+ json_inputs.append(raw)
745
+ md_inputs.append(raw)
746
+ continue
747
+ raise click.ClickException(f"Input path not found: {path}")
748
+ json_paths = discover_json(json_inputs, recursive=recursive) if json_inputs else []
749
+ md_paths = discover_markdown(md_inputs, None, recursive=recursive) if md_inputs else []
750
+ if json_paths and not md_paths:
751
+ json_mode = True
752
+ paths = json_paths
753
+ click.echo("Detected JSON inputs; enabling --json mode")
754
+ elif md_paths and not json_paths:
755
+ paths = md_paths
756
+ elif json_paths and md_paths:
757
+ raise click.ClickException(
758
+ "Found both markdown and JSON inputs; split inputs or pass --json explicitly"
759
+ )
760
+ else:
761
+ paths = []
577
762
  if not paths:
578
- click.echo("No markdown files discovered")
763
+ click.echo("No files discovered")
579
764
  return
580
765
 
581
766
  format_enabled = not no_format
582
767
  if in_place:
583
768
  output_map = {path: path for path in paths}
584
769
  else:
585
- output_map = {path: (output_path / name) for path, name in _map_output_files(paths, [output_path]).items()}
770
+ ext = ".json" if json_mode else ".md"
771
+ output_map = {
772
+ path: (output_path / name)
773
+ for path, name in _map_output_files(paths, [output_path], ext=ext).items()
774
+ }
586
775
 
587
776
  if dry_run:
588
777
  rows = [
778
+ ("Mode", "json" if json_mode else "markdown"),
589
779
  ("Inputs", str(len(paths))),
590
780
  ("Outputs", str(len(output_map))),
591
781
  ("Fix level", fix_level),
@@ -599,25 +789,614 @@ def recognize_fix(
599
789
 
600
790
  progress = tqdm(total=len(paths), desc="fix", unit="file")
601
791
  try:
602
- asyncio.run(
603
- _run_fix(
604
- paths,
605
- output_map,
606
- fix_level,
607
- format_enabled,
608
- workers,
609
- progress,
792
+ if json_mode:
793
+ results = asyncio.run(
794
+ _run_fix_json(
795
+ paths,
796
+ output_map,
797
+ fix_level,
798
+ format_enabled,
799
+ workers,
800
+ progress,
801
+ )
802
+ )
803
+ else:
804
+ asyncio.run(
805
+ _run_fix(
806
+ paths,
807
+ output_map,
808
+ fix_level,
809
+ format_enabled,
810
+ workers,
811
+ progress,
812
+ )
813
+ )
814
+ finally:
815
+ progress.close()
816
+ if json_mode:
817
+ total_items = sum(result[0] for result in results)
818
+ items_processed = sum(result[1] for result in results)
819
+ items_updated = sum(result[2] for result in results)
820
+ fields_total = sum(result[3] for result in results)
821
+ fields_updated = sum(result[4] for result in results)
822
+ items_skipped = total_items - items_processed
823
+ rows = [
824
+ ("Mode", "json"),
825
+ ("Inputs", str(len(paths))),
826
+ ("Outputs", str(len(output_map))),
827
+ ("Items", str(total_items)),
828
+ ("Items processed", str(items_processed)),
829
+ ("Items skipped", str(items_skipped)),
830
+ ("Items updated", str(items_updated)),
831
+ ("Fields processed", str(fields_total)),
832
+ ("Fields updated", str(fields_updated)),
833
+ ("Fix level", fix_level),
834
+ ("Format", "no" if no_format else "yes"),
835
+ ("In place", "yes" if in_place else "no"),
836
+ ("Output dir", _relative_path(output_path) if output_path else "-"),
837
+ ("Duration", _format_duration(time.monotonic() - start_time)),
838
+ ]
839
+ else:
840
+ rows = [
841
+ ("Mode", "markdown"),
842
+ ("Inputs", str(len(paths))),
843
+ ("Outputs", str(len(output_map))),
844
+ ("Fix level", fix_level),
845
+ ("Format", "no" if no_format else "yes"),
846
+ ("In place", "yes" if in_place else "no"),
847
+ ("Output dir", _relative_path(output_path) if output_path else "-"),
848
+ ("Duration", _format_duration(time.monotonic() - start_time)),
849
+ ]
850
+ _print_summary("recognize fix", rows)
851
+
852
+
853
+ @recognize.command("fix-math")
854
+ @click.option("-c", "--config", "config_path", default="config.toml", help="Path to config.toml")
855
+ @click.option(
856
+ "-i",
857
+ "--input",
858
+ "inputs",
859
+ multiple=True,
860
+ required=True,
861
+ help="Input markdown or JSON file/directory (repeatable)",
862
+ )
863
+ @click.option("-o", "--output", "output_dir", default=None, help="Output directory")
864
+ @click.option("--in-place", "in_place", is_flag=True, help="Fix formulas in place")
865
+ @click.option("-r", "--recursive", is_flag=True, help="Recursively discover files")
866
+ @click.option("--json", "json_mode", is_flag=True, help="Process JSON inputs instead of markdown")
867
+ @click.option("-m", "--model", "model_ref", required=True, help="provider/model")
868
+ @click.option("--batch-size", "batch_size", default=10, show_default=True, type=int)
869
+ @click.option("--context-chars", "context_chars", default=80, show_default=True, type=int)
870
+ @click.option("--max-retries", "max_retries", default=3, show_default=True, type=int)
871
+ @click.option("--workers", type=int, default=4, show_default=True, help="Concurrent workers")
872
+ @click.option("--timeout", "timeout", default=120.0, show_default=True, type=float)
873
+ @click.option(
874
+ "--only-show-error",
875
+ "only_show_error",
876
+ is_flag=True,
877
+ help="Only validate formulas and report error counts",
878
+ )
879
+ @click.option("--report", "report_path", default=None, help="Error report output path")
880
+ @click.option("--dry-run", is_flag=True, help="Report actions without writing files")
881
+ @click.option("-v", "--verbose", is_flag=True, help="Enable verbose logging")
882
+ def recognize_fix_math(
883
+ config_path: str,
884
+ inputs: tuple[str, ...],
885
+ output_dir: str | None,
886
+ in_place: bool,
887
+ recursive: bool,
888
+ json_mode: bool,
889
+ model_ref: str,
890
+ batch_size: int,
891
+ context_chars: int,
892
+ max_retries: int,
893
+ workers: int,
894
+ timeout: float,
895
+ only_show_error: bool,
896
+ report_path: str | None,
897
+ dry_run: bool,
898
+ verbose: bool,
899
+ ) -> None:
900
+ """Validate and repair LaTeX formulas in markdown or JSON outputs."""
901
+ configure_logging(verbose)
902
+ if in_place and output_dir:
903
+ raise click.ClickException("--in-place cannot be used with --output")
904
+ if not only_show_error and not in_place and not output_dir:
905
+ raise click.ClickException("Either --in-place or --output is required")
906
+ if batch_size <= 0:
907
+ raise click.ClickException("--batch-size must be positive")
908
+ if context_chars < 0:
909
+ raise click.ClickException("--context-chars must be non-negative")
910
+ if max_retries < 0:
911
+ raise click.ClickException("--max-retries must be non-negative")
912
+ if workers <= 0:
913
+ raise click.ClickException("--workers must be positive")
914
+ try:
915
+ require_pylatexenc()
916
+ except RuntimeError as exc:
917
+ raise click.ClickException(str(exc)) from exc
918
+
919
+ if not json_mode:
920
+ file_types: set[str] = set()
921
+ for raw in inputs:
922
+ path = Path(raw)
923
+ if path.is_file():
924
+ suffix = path.suffix.lower()
925
+ if suffix in {".md", ".json"}:
926
+ file_types.add(suffix)
927
+ if ".md" in file_types and ".json" in file_types:
928
+ raise click.ClickException(
929
+ "Mixed markdown and JSON inputs. Use --json for JSON or split commands."
610
930
  )
931
+ if ".json" in file_types:
932
+ json_mode = True
933
+ logger.info("Detected JSON inputs; enabling --json mode")
934
+
935
+ config = load_config(config_path)
936
+ provider, model_name = parse_model_ref(model_ref, config.providers)
937
+ api_keys = resolve_api_keys(provider.api_keys)
938
+ if provider.type in {
939
+ "openai_compatible",
940
+ "dashscope",
941
+ "gemini_ai_studio",
942
+ "azure_openai",
943
+ "claude",
944
+ } and not api_keys:
945
+ raise click.ClickException(f"{provider.type} providers require api_keys")
946
+ api_key = api_keys[0] if api_keys else None
947
+
948
+ if json_mode:
949
+ paths = discover_json(inputs, recursive=recursive)
950
+ else:
951
+ paths = discover_markdown(inputs, None, recursive=recursive)
952
+ if not paths:
953
+ click.echo("No files discovered")
954
+ return
955
+
956
+ output_path = Path(output_dir) if output_dir else None
957
+ if output_path and not dry_run and not only_show_error:
958
+ output_path = _ensure_output_dir(output_dir)
959
+ _warn_if_not_empty(output_path)
960
+
961
+ if in_place:
962
+ output_map = {path: path for path in paths}
963
+ elif output_path:
964
+ ext = ".json" if json_mode else ".md"
965
+ output_map = {
966
+ path: (output_path / name)
967
+ for path, name in _map_output_files(paths, [output_path], ext=ext).items()
968
+ }
969
+ else:
970
+ output_map = {path: path for path in paths}
971
+
972
+ report_target = None
973
+ if report_path:
974
+ report_target = Path(report_path)
975
+ elif not only_show_error:
976
+ if output_path:
977
+ report_target = output_path / "fix-math-errors.json"
978
+ elif in_place:
979
+ report_target = Path.cwd() / "fix-math-errors.json"
980
+
981
+ if dry_run and not only_show_error:
982
+ rows = [
983
+ ("Mode", "json" if json_mode else "markdown"),
984
+ ("Inputs", str(len(paths))),
985
+ ("Outputs", str(len(output_map))),
986
+ ("Batch size", str(batch_size)),
987
+ ("Context chars", str(context_chars)),
988
+ ("Max retries", str(max_retries)),
989
+ ("Workers", str(workers)),
990
+ ("Timeout", f"{timeout:.1f}s"),
991
+ ("Only show error", "yes" if only_show_error else "no"),
992
+ ("In place", "yes" if in_place else "no"),
993
+ ("Output dir", _relative_path(output_path) if output_path else "-"),
994
+ ("Report", _relative_path(report_target) if report_target else "-"),
995
+ ]
996
+ _print_summary("recognize fix-math (dry-run)", rows)
997
+ return
998
+
999
+ progress = tqdm(total=len(paths), desc="fix-math", unit="file")
1000
+ formula_progress = tqdm(total=0, desc="formulas", unit="formula")
1001
+ error_records: list[dict[str, Any]] = []
1002
+
1003
+ async def run() -> MathFixStats:
1004
+ semaphore = asyncio.Semaphore(workers)
1005
+ progress_lock = asyncio.Lock()
1006
+ stats_total = MathFixStats()
1007
+
1008
+ async with httpx.AsyncClient() as client:
1009
+ async def handle_path(path: Path) -> MathFixStats:
1010
+ stats = MathFixStats()
1011
+ if json_mode:
1012
+ raw_text = read_text(path)
1013
+ items, payload, template_tag = _load_json_payload(path)
1014
+ cursor = 0
1015
+ for item_index, item in enumerate(items):
1016
+ if not isinstance(item, dict):
1017
+ continue
1018
+ template = _resolve_item_template(item, template_tag)
1019
+ fields = _template_markdown_fields(template)
1020
+ for field in fields:
1021
+ value = item.get(field)
1022
+ if not isinstance(value, str):
1023
+ continue
1024
+ spans = extract_math_spans(value, context_chars)
1025
+ if spans:
1026
+ formula_progress.total += len(spans)
1027
+ formula_progress.refresh()
1028
+ line_start, cursor = locate_json_field_start(raw_text, value, cursor)
1029
+ field_path = f"papers[{item_index}].{field}"
1030
+ updated, errors = await fix_math_text(
1031
+ value,
1032
+ str(path),
1033
+ line_start,
1034
+ field_path,
1035
+ item_index,
1036
+ provider,
1037
+ model_name,
1038
+ api_key,
1039
+ timeout,
1040
+ max_retries,
1041
+ batch_size,
1042
+ context_chars,
1043
+ client,
1044
+ stats,
1045
+ repair_enabled=not only_show_error,
1046
+ spans=spans,
1047
+ progress_cb=lambda: formula_progress.update(1),
1048
+ )
1049
+ if not only_show_error and updated != value:
1050
+ item[field] = updated
1051
+ error_records.extend(errors)
1052
+ if not only_show_error:
1053
+ output_data: Any = items if payload is None else {**payload, "papers": items}
1054
+ output_path = output_map[path]
1055
+ serialized = json.dumps(output_data, ensure_ascii=False, indent=2)
1056
+ await asyncio.to_thread(output_path.write_text, f"{serialized}\n", encoding="utf-8")
1057
+ else:
1058
+ content = await asyncio.to_thread(read_text, path)
1059
+ spans = extract_math_spans(content, context_chars)
1060
+ if spans:
1061
+ formula_progress.total += len(spans)
1062
+ formula_progress.refresh()
1063
+ updated, errors = await fix_math_text(
1064
+ content,
1065
+ str(path),
1066
+ 1,
1067
+ None,
1068
+ None,
1069
+ provider,
1070
+ model_name,
1071
+ api_key,
1072
+ timeout,
1073
+ max_retries,
1074
+ batch_size,
1075
+ context_chars,
1076
+ client,
1077
+ stats,
1078
+ repair_enabled=not only_show_error,
1079
+ spans=spans,
1080
+ progress_cb=lambda: formula_progress.update(1),
1081
+ )
1082
+ if not only_show_error:
1083
+ output_path = output_map[path]
1084
+ await asyncio.to_thread(output_path.write_text, updated, encoding="utf-8")
1085
+ error_records.extend(errors)
1086
+ return stats
1087
+
1088
+ async def runner(path: Path) -> None:
1089
+ async with semaphore:
1090
+ stats = await handle_path(path)
1091
+ stats_total.formulas_total += stats.formulas_total
1092
+ stats_total.formulas_invalid += stats.formulas_invalid
1093
+ stats_total.formulas_cleaned += stats.formulas_cleaned
1094
+ stats_total.formulas_repaired += stats.formulas_repaired
1095
+ stats_total.formulas_failed += stats.formulas_failed
1096
+ async with progress_lock:
1097
+ progress.update(1)
1098
+
1099
+ await asyncio.gather(*(runner(path) for path in paths))
1100
+ return stats_total
1101
+
1102
+ try:
1103
+ stats = asyncio.run(run())
1104
+ finally:
1105
+ progress.close()
1106
+ formula_progress.close()
1107
+
1108
+ if report_target and error_records:
1109
+ report_target.parent.mkdir(parents=True, exist_ok=True)
1110
+ report_target.write_text(
1111
+ json.dumps(error_records, ensure_ascii=False, indent=2) + "\n",
1112
+ encoding="utf-8",
611
1113
  )
1114
+
1115
+ rows = [
1116
+ ("Mode", "json" if json_mode else "markdown"),
1117
+ ("Inputs", str(len(paths))),
1118
+ ("Outputs", str(len(output_map) if not only_show_error else 0)),
1119
+ ("Formulas", str(stats.formulas_total)),
1120
+ ("Invalid", str(stats.formulas_invalid)),
1121
+ ("Cleaned", str(stats.formulas_cleaned)),
1122
+ ("Repaired", str(stats.formulas_repaired)),
1123
+ ("Failed", str(stats.formulas_failed)),
1124
+ ("Only show error", "yes" if only_show_error else "no"),
1125
+ ("Report", _relative_path(report_target) if report_target else "-"),
1126
+ ]
1127
+ _print_summary("recognize fix-math", rows)
1128
+
1129
+
1130
+ @recognize.command("fix-mermaid")
1131
+ @click.option("-c", "--config", "config_path", default="config.toml", help="Path to config.toml")
1132
+ @click.option(
1133
+ "-i",
1134
+ "--input",
1135
+ "inputs",
1136
+ multiple=True,
1137
+ required=True,
1138
+ help="Input markdown or JSON file/directory (repeatable)",
1139
+ )
1140
+ @click.option("-o", "--output", "output_dir", default=None, help="Output directory")
1141
+ @click.option("--in-place", "in_place", is_flag=True, help="Fix Mermaid blocks in place")
1142
+ @click.option("-r", "--recursive", is_flag=True, help="Recursively discover files")
1143
+ @click.option("--json", "json_mode", is_flag=True, help="Process JSON inputs instead of markdown")
1144
+ @click.option("-m", "--model", "model_ref", required=True, help="provider/model")
1145
+ @click.option("--batch-size", "batch_size", default=10, show_default=True, type=int)
1146
+ @click.option("--context-chars", "context_chars", default=80, show_default=True, type=int)
1147
+ @click.option("--max-retries", "max_retries", default=3, show_default=True, type=int)
1148
+ @click.option("--workers", type=int, default=4, show_default=True, help="Concurrent workers")
1149
+ @click.option("--timeout", "timeout", default=120.0, show_default=True, type=float)
1150
+ @click.option(
1151
+ "--only-show-error",
1152
+ "only_show_error",
1153
+ is_flag=True,
1154
+ help="Only validate Mermaid blocks and report error counts",
1155
+ )
1156
+ @click.option("--report", "report_path", default=None, help="Error report output path")
1157
+ @click.option("--dry-run", is_flag=True, help="Report actions without writing files")
1158
+ @click.option("-v", "--verbose", is_flag=True, help="Enable verbose logging")
1159
+ def recognize_fix_mermaid(
1160
+ config_path: str,
1161
+ inputs: tuple[str, ...],
1162
+ output_dir: str | None,
1163
+ in_place: bool,
1164
+ recursive: bool,
1165
+ json_mode: bool,
1166
+ model_ref: str,
1167
+ batch_size: int,
1168
+ context_chars: int,
1169
+ max_retries: int,
1170
+ workers: int,
1171
+ timeout: float,
1172
+ only_show_error: bool,
1173
+ report_path: str | None,
1174
+ dry_run: bool,
1175
+ verbose: bool,
1176
+ ) -> None:
1177
+ """Validate and repair Mermaid diagrams in markdown or JSON outputs."""
1178
+ configure_logging(verbose)
1179
+ if in_place and output_dir:
1180
+ raise click.ClickException("--in-place cannot be used with --output")
1181
+ if not only_show_error and not in_place and not output_dir:
1182
+ raise click.ClickException("Either --in-place or --output is required")
1183
+ if batch_size <= 0:
1184
+ raise click.ClickException("--batch-size must be positive")
1185
+ if context_chars < 0:
1186
+ raise click.ClickException("--context-chars must be non-negative")
1187
+ if max_retries < 0:
1188
+ raise click.ClickException("--max-retries must be non-negative")
1189
+ if workers <= 0:
1190
+ raise click.ClickException("--workers must be positive")
1191
+ try:
1192
+ require_mmdc()
1193
+ except RuntimeError as exc:
1194
+ raise click.ClickException(str(exc)) from exc
1195
+
1196
+ if not json_mode:
1197
+ file_types: set[str] = set()
1198
+ for raw in inputs:
1199
+ path = Path(raw)
1200
+ if path.is_file():
1201
+ suffix = path.suffix.lower()
1202
+ if suffix in {".md", ".json"}:
1203
+ file_types.add(suffix)
1204
+ if ".md" in file_types and ".json" in file_types:
1205
+ raise click.ClickException(
1206
+ "Mixed markdown and JSON inputs. Use --json for JSON or split commands."
1207
+ )
1208
+ if ".json" in file_types:
1209
+ json_mode = True
1210
+ logger.info("Detected JSON inputs; enabling --json mode")
1211
+
1212
+ config = load_config(config_path)
1213
+ provider, model_name = parse_model_ref(model_ref, config.providers)
1214
+ api_keys = resolve_api_keys(provider.api_keys)
1215
+ if provider.type in {
1216
+ "openai_compatible",
1217
+ "dashscope",
1218
+ "gemini_ai_studio",
1219
+ "azure_openai",
1220
+ "claude",
1221
+ } and not api_keys:
1222
+ raise click.ClickException(f"{provider.type} providers require api_keys")
1223
+ api_key = api_keys[0] if api_keys else None
1224
+
1225
+ if json_mode:
1226
+ paths = discover_json(inputs, recursive=recursive)
1227
+ else:
1228
+ paths = discover_markdown(inputs, None, recursive=recursive)
1229
+ if not paths:
1230
+ click.echo("No files discovered")
1231
+ return
1232
+
1233
+ output_path = Path(output_dir) if output_dir else None
1234
+ if output_path and not dry_run and not only_show_error:
1235
+ output_path = _ensure_output_dir(output_dir)
1236
+ _warn_if_not_empty(output_path)
1237
+
1238
+ if in_place:
1239
+ output_map = {path: path for path in paths}
1240
+ elif output_path:
1241
+ ext = ".json" if json_mode else ".md"
1242
+ output_map = {
1243
+ path: (output_path / name)
1244
+ for path, name in _map_output_files(paths, [output_path], ext=ext).items()
1245
+ }
1246
+ else:
1247
+ output_map = {path: path for path in paths}
1248
+
1249
+ report_target = None
1250
+ if report_path:
1251
+ report_target = Path(report_path)
1252
+ elif not only_show_error:
1253
+ if output_path:
1254
+ report_target = output_path / "fix-mermaid-errors.json"
1255
+ elif in_place:
1256
+ report_target = Path.cwd() / "fix-mermaid-errors.json"
1257
+
1258
+ if dry_run and not only_show_error:
1259
+ rows = [
1260
+ ("Mode", "json" if json_mode else "markdown"),
1261
+ ("Inputs", str(len(paths))),
1262
+ ("Outputs", str(len(output_map))),
1263
+ ("Batch size", str(batch_size)),
1264
+ ("Context chars", str(context_chars)),
1265
+ ("Max retries", str(max_retries)),
1266
+ ("Workers", str(workers)),
1267
+ ("Timeout", f"{timeout:.1f}s"),
1268
+ ("Only show error", "yes" if only_show_error else "no"),
1269
+ ("In place", "yes" if in_place else "no"),
1270
+ ("Output dir", _relative_path(output_path) if output_path else "-"),
1271
+ ("Report", _relative_path(report_target) if report_target else "-"),
1272
+ ]
1273
+ _print_summary("recognize fix-mermaid (dry-run)", rows)
1274
+ return
1275
+
1276
+ progress = tqdm(total=len(paths), desc="fix-mermaid", unit="file")
1277
+ diagram_progress = tqdm(total=0, desc="diagrams", unit="diagram")
1278
+ error_records: list[dict[str, Any]] = []
1279
+
1280
async def run() -> MermaidFixStats:
    """Repair mermaid diagrams in all discovered files concurrently.

    Returns the aggregated :class:`MermaidFixStats` across every file.

    NOTE(review): this is a closure over the enclosing CLI command's state —
    ``workers``, ``json_mode``, ``only_show_error``, ``output_map``, the
    provider/model/api_key/timeout/retry/batch settings, the two tqdm bars
    (``progress`` per file, ``diagram_progress`` per diagram) and the shared
    ``error_records`` list. Reconstructed formatting: indentation is inferred
    from the diff rendering, not guaranteed byte-identical to the original.
    """
    # Bound the number of files processed at once; each runner holds one slot.
    semaphore = asyncio.Semaphore(workers)
    progress_lock = asyncio.Lock()
    stats_total = MermaidFixStats()

    # One shared HTTP client is reused for every repair request.
    async with httpx.AsyncClient() as client:
        async def handle_path(path: Path) -> MermaidFixStats:
            """Fix diagrams in a single file and return that file's stats."""
            stats = MermaidFixStats()
            if json_mode:
                # JSON mode: walk the "papers" items and repair each
                # markdown-bearing field declared by the item's template.
                raw_text = read_text(path)
                items, payload, template_tag = _load_json_payload(path)
                # Forward-moving search offset into raw_text used to map each
                # field value back to its line number for error reports.
                cursor = 0
                for item_index, item in enumerate(items):
                    if not isinstance(item, dict):
                        continue
                    template = _resolve_item_template(item, template_tag)
                    fields = _template_markdown_fields(template)
                    for field in fields:
                        value = item.get(field)
                        if not isinstance(value, str):
                            continue
                        spans = extract_mermaid_spans(value, context_chars)
                        if spans:
                            # Diagram total is discovered lazily; grow the bar
                            # as new diagrams are found. Safe without a lock:
                            # all mutation happens on the single event loop.
                            diagram_progress.total += len(spans)
                            diagram_progress.refresh()
                        # Presumably fields occur in file order so the cursor
                        # only ever advances — TODO confirm against
                        # locate_json_field_start's contract.
                        line_start, cursor = locate_json_field_start(raw_text, value, cursor)
                        field_path = f"papers[{item_index}].{field}"
                        updated, errors = await fix_mermaid_text(
                            value,
                            str(path),
                            line_start,
                            field_path,
                            item_index,
                            provider,
                            model_name,
                            api_key,
                            timeout,
                            max_retries,
                            batch_size,
                            context_chars,
                            client,
                            stats,
                            # In "only show error" mode we validate/report but
                            # never rewrite diagram text.
                            repair_enabled=not only_show_error,
                            spans=spans,
                            progress_cb=lambda: diagram_progress.update(1),
                        )
                        if not only_show_error and updated != value:
                            item[field] = updated
                        error_records.extend(errors)
                if not only_show_error:
                    # Re-wrap items in the original payload envelope (if any)
                    # and write the result; file I/O is pushed off the loop.
                    output_data: Any = items if payload is None else {**payload, "papers": items}
                    output_path = output_map[path]
                    serialized = json.dumps(output_data, ensure_ascii=False, indent=2)
                    await asyncio.to_thread(output_path.write_text, f"{serialized}\n", encoding="utf-8")
            else:
                # Markdown mode: the whole file is one text to scan and fix.
                content = await asyncio.to_thread(read_text, path)
                spans = extract_mermaid_spans(content, context_chars)
                if spans:
                    diagram_progress.total += len(spans)
                    diagram_progress.refresh()
                # NOTE(review): diff rendering loses indentation — this call
                # appears to run unconditionally (fix_mermaid_text receives
                # spans and is a no-op for an empty list); confirm against the
                # original source.
                updated, errors = await fix_mermaid_text(
                    content,
                    str(path),
                    1,
                    None,
                    None,
                    provider,
                    model_name,
                    api_key,
                    timeout,
                    max_retries,
                    batch_size,
                    context_chars,
                    client,
                    stats,
                    repair_enabled=not only_show_error,
                    spans=spans,
                    progress_cb=lambda: diagram_progress.update(1),
                )
                if not only_show_error:
                    output_path = output_map[path]
                    await asyncio.to_thread(output_path.write_text, updated, encoding="utf-8")
                error_records.extend(errors)
            return stats

        async def runner(path: Path) -> None:
            """Semaphore-gated wrapper: process one file, then fold its stats
            into the shared totals and tick the file progress bar."""
            async with semaphore:
                stats = await handle_path(path)
                stats_total.diagrams_total += stats.diagrams_total
                stats_total.diagrams_invalid += stats.diagrams_invalid
                stats_total.diagrams_repaired += stats.diagrams_repaired
                stats_total.diagrams_failed += stats.diagrams_failed
                async with progress_lock:
                    progress.update(1)

        await asyncio.gather(*(runner(path) for path in paths))
    return stats_total
1377
+
1378
+ try:
1379
+ stats = asyncio.run(run())
612
1380
  finally:
613
1381
  progress.close()
1382
+ diagram_progress.close()
1383
+
1384
+ if report_target and error_records:
1385
+ report_target.parent.mkdir(parents=True, exist_ok=True)
1386
+ report_target.write_text(
1387
+ json.dumps(error_records, ensure_ascii=False, indent=2) + "\n",
1388
+ encoding="utf-8",
1389
+ )
1390
+
614
1391
  rows = [
1392
+ ("Mode", "json" if json_mode else "markdown"),
615
1393
  ("Inputs", str(len(paths))),
616
- ("Outputs", str(len(output_map))),
617
- ("Fix level", fix_level),
618
- ("Format", "no" if no_format else "yes"),
619
- ("In place", "yes" if in_place else "no"),
620
- ("Output dir", _relative_path(output_path) if output_path else "-"),
621
- ("Duration", _format_duration(time.monotonic() - start_time)),
1394
+ ("Outputs", str(len(output_map) if not only_show_error else 0)),
1395
+ ("Diagrams", str(stats.diagrams_total)),
1396
+ ("Invalid", str(stats.diagrams_invalid)),
1397
+ ("Repaired", str(stats.diagrams_repaired)),
1398
+ ("Failed", str(stats.diagrams_failed)),
1399
+ ("Only show error", "yes" if only_show_error else "no"),
1400
+ ("Report", _relative_path(report_target) if report_target else "-"),
622
1401
  ]
623
- _print_summary("recognize fix", rows)
1402
+ _print_summary("recognize fix-mermaid", rows)