deepresearch-flow 0.5.1__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepresearch_flow/paper/cli.py +63 -0
- deepresearch_flow/paper/config.py +87 -12
- deepresearch_flow/paper/db.py +1041 -34
- deepresearch_flow/paper/db_ops.py +124 -19
- deepresearch_flow/paper/extract.py +1546 -152
- deepresearch_flow/paper/prompt_templates/deep_read_phi_system.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/deep_read_phi_user.j2 +5 -0
- deepresearch_flow/paper/prompt_templates/deep_read_system.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/deep_read_user.j2 +272 -40
- deepresearch_flow/paper/prompt_templates/eight_questions_phi_system.j2 +1 -0
- deepresearch_flow/paper/prompt_templates/eight_questions_phi_user.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/eight_questions_system.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/eight_questions_user.j2 +4 -0
- deepresearch_flow/paper/prompt_templates/simple_phi_system.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/simple_system.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/simple_user.j2 +2 -0
- deepresearch_flow/paper/providers/azure_openai.py +45 -3
- deepresearch_flow/paper/providers/openai_compatible.py +45 -3
- deepresearch_flow/paper/schemas/deep_read_phi_schema.json +1 -0
- deepresearch_flow/paper/schemas/deep_read_schema.json +1 -0
- deepresearch_flow/paper/schemas/default_paper_schema.json +6 -0
- deepresearch_flow/paper/schemas/eight_questions_schema.json +1 -0
- deepresearch_flow/paper/snapshot/__init__.py +4 -0
- deepresearch_flow/paper/snapshot/api.py +941 -0
- deepresearch_flow/paper/snapshot/builder.py +965 -0
- deepresearch_flow/paper/snapshot/identity.py +239 -0
- deepresearch_flow/paper/snapshot/schema.py +245 -0
- deepresearch_flow/paper/snapshot/tests/__init__.py +2 -0
- deepresearch_flow/paper/snapshot/tests/test_identity.py +123 -0
- deepresearch_flow/paper/snapshot/text.py +154 -0
- deepresearch_flow/paper/template_registry.py +1 -0
- deepresearch_flow/paper/templates/deep_read.md.j2 +4 -0
- deepresearch_flow/paper/templates/deep_read_phi.md.j2 +4 -0
- deepresearch_flow/paper/templates/default_paper.md.j2 +4 -0
- deepresearch_flow/paper/templates/eight_questions.md.j2 +4 -0
- deepresearch_flow/paper/web/app.py +10 -3
- deepresearch_flow/recognize/cli.py +380 -103
- deepresearch_flow/recognize/markdown.py +31 -7
- deepresearch_flow/recognize/math.py +47 -12
- deepresearch_flow/recognize/mermaid.py +320 -10
- deepresearch_flow/recognize/organize.py +29 -7
- deepresearch_flow/translator/cli.py +71 -20
- deepresearch_flow/translator/engine.py +220 -81
- deepresearch_flow/translator/prompts.py +19 -2
- deepresearch_flow/translator/protector.py +15 -3
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/METADATA +407 -33
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/RECORD +51 -43
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/WHEEL +1 -1
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/entry_points.txt +0 -0
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/top_level.txt +0 -0
@@ -42,6 +42,10 @@ from deepresearch_flow.recognize.mermaid import (
     extract_mermaid_spans,
     fix_mermaid_text,
     require_mmdc,
+    extract_diagrams_from_text,
+    repair_all_diagrams_global,
+    DiagramTask,
+    apply_replacements,
 )
 from deepresearch_flow.recognize.organize import (
     discover_mineru_dirs,
@@ -73,7 +77,8 @@ def _relative_path(path: Path) -> str:
 
 
 def _warn_if_not_empty(output_dir: Path) -> None:
     if output_dir.exists() and any(output_dir.iterdir()):
-        logger.warning("Output directory not empty: %s", output_dir)
+        item_count = sum(1 for _ in output_dir.iterdir())
+        logger.warning("Output directory not empty: %s (items=%d)", output_dir, item_count)
 
 
 def _print_summary(title: str, rows: list[tuple[str, str]]) -> None:
@@ -114,6 +119,60 @@ def _map_output_files(
     return mapping
 
 
+RetryKey = tuple[int, str | None, int | None]
+
+
+def _load_retry_targets(report_path: Path) -> dict[Path, set[RetryKey]]:
+    if not report_path.exists():
+        raise click.ClickException(f"Retry report not found: {report_path}")
+    try:
+        payload = json.loads(read_text(report_path))
+    except json.JSONDecodeError as exc:
+        raise click.ClickException(f"Retry report is not valid JSON: {exc}") from exc
+    if not isinstance(payload, list) or not payload:
+        raise click.ClickException(f"Retry report is empty: {report_path}")
+    targets: dict[Path, set[RetryKey]] = {}
+    for entry in payload:
+        if not isinstance(entry, dict):
+            continue
+        path_raw = entry.get("path")
+        line_raw = entry.get("line")
+        if not path_raw or line_raw is None:
+            continue
+        try:
+            line_no = int(line_raw)
+        except (TypeError, ValueError):
+            continue
+        field_path = entry.get("field_path")
+        if not isinstance(field_path, str):
+            field_path = None
+        item_index = entry.get("item_index")
+        if not isinstance(item_index, int):
+            item_index = None
+        key = (line_no, field_path, item_index)
+        targets.setdefault(Path(path_raw).resolve(), set()).add(key)
+    if not targets:
+        raise click.ClickException(f"Retry report has no valid entries: {report_path}")
+    return targets
+
+
+def _filter_retry_spans(
+    spans: list[Any],
+    line_offset: int,
+    field_path: str | None,
+    item_index: int | None,
+    retry_keys: set[RetryKey] | None,
+) -> list[Any]:
+    if not retry_keys:
+        return spans
+    filtered: list[Any] = []
+    for span in spans:
+        line_no = line_offset + span.line - 1
+        if (line_no, field_path, item_index) in retry_keys:
+            filtered.append(span)
+    return filtered
+
+
 def _aggregate_image_counts(paths: Iterable[Path]) -> dict[str, int]:
     totals = {"total": 0, "data": 0, "http": 0, "local": 0}
     for path in paths:
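For reference, `_load_retry_targets` accepts a JSON array of error records keyed by `path` and `line`, with optional `field_path` and `item_index` for JSON-mode entries. A minimal sketch of a report it would accept — the key names come from the parser above; the paths and values are illustrative:

```python
# Illustrative retry report; only the keys are taken from _load_retry_targets.
import json
from pathlib import Path

report = [
    {
        "path": "out/papers.json",          # file the failed span lives in
        "line": 120,                        # absolute 1-based line of the span
        "field_path": "papers[3].summary",  # JSON mode; omitted for markdown
        "item_index": 3,                    # JSON mode; omitted for markdown
    },
    {"path": "out/notes.md", "line": 57},   # markdown mode: path + line suffice
]
Path("fix-math-errors.json").write_text(
    json.dumps(report, indent=2), encoding="utf-8"
)
```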
@@ -194,6 +253,8 @@ async def _fix_json_items(
     default_template: str | None,
     fix_level: str,
     format_enabled: bool,
+    progress: tqdm | None = None,
+    progress_lock: asyncio.Lock | None = None,
 ) -> tuple[int, int, int, int]:
     items_total = 0
     items_updated = 0
@@ -218,6 +279,9 @@ async def _fix_json_items(
                 item_updated = True
         if item_updated:
             items_updated += 1
+            if progress and progress_lock:
+                async with progress_lock:
+                    progress.update(1)
     return items_total, items_updated, fields_total, fields_updated
 
 
@@ -350,7 +414,12 @@ async def _run_fix_json(
     async def handler(path: Path) -> tuple[int, int, int, int, int]:
         items, payload, template_tag = _load_json_payload(path)
         items_total, items_updated, fields_total, fields_updated = await _fix_json_items(
-            items,
+            items,
+            template_tag,
+            fix_level,
+            format_enabled,
+            progress,
+            progress_lock,
         )
         output_data: Any
         if payload is None:
@@ -367,9 +436,6 @@ async def _run_fix_json(
         async with semaphore:
             result = await handler(path)
             results.append(result)
-            if progress and progress_lock:
-                async with progress_lock:
-                    progress.update(1)
 
     await asyncio.gather(*(runner(path) for path in paths))
     return results
@@ -787,7 +853,16 @@ def recognize_fix(
         _print_summary("recognize fix (dry-run)", rows)
         return
 
-    progress = tqdm(total=len(paths), desc="fix", unit="file")
+    progress_total = len(paths)
+    progress_unit = "file"
+    if json_mode:
+        json_items_total = 0
+        for path in paths:
+            items, _, _ = _load_json_payload(path)
+            json_items_total += sum(1 for item in items if isinstance(item, dict))
+        progress_total = json_items_total
+        progress_unit = "item"
+    progress = tqdm(total=progress_total, desc="fix", unit=progress_unit)
     try:
         if json_mode:
             results = asyncio.run(
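With this change the `fix` progress bar ticks once per JSON item instead of once per file, with updates funneled through `_fix_json_items` under a shared lock. A standalone sketch of that shared-bar pattern, using only `tqdm` and the standard library (the names here are illustrative, not the package's API):

```python
# Sketch: many asyncio tasks ticking one tqdm bar under an asyncio.Lock.
import asyncio

from tqdm import tqdm

async def process_item(item: int, bar: tqdm, lock: asyncio.Lock) -> None:
    await asyncio.sleep(0.01)  # stand-in for the real per-item fix work
    async with lock:  # tqdm is not async-aware, so serialize updates
        bar.update(1)

async def main() -> None:
    items = list(range(100))
    bar = tqdm(total=len(items), desc="fix", unit="item")
    lock = asyncio.Lock()
    try:
        await asyncio.gather(*(process_item(i, bar, lock) for i in items))
    finally:
        bar.close()

asyncio.run(main())
```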
@@ -870,6 +945,7 @@ def recognize_fix(
 @click.option("--max-retries", "max_retries", default=3, show_default=True, type=int)
 @click.option("--workers", type=int, default=4, show_default=True, help="Concurrent workers")
 @click.option("--timeout", "timeout", default=120.0, show_default=True, type=float)
+@click.option("--retry-failed", "retry_failed", is_flag=True, help="Retry only failed formulas")
 @click.option(
     "--only-show-error",
     "only_show_error",
@@ -892,6 +968,7 @@ def recognize_fix_math(
     max_retries: int,
     workers: int,
     timeout: float,
+    retry_failed: bool,
     only_show_error: bool,
     report_path: str | None,
     dry_run: bool,
@@ -911,6 +988,8 @@ def recognize_fix_math(
         raise click.ClickException("--max-retries must be non-negative")
     if workers <= 0:
         raise click.ClickException("--workers must be positive")
+    if retry_failed and only_show_error:
+        raise click.ClickException("--retry-failed cannot be used with --only-show-error")
     try:
         require_pylatexenc()
     except RuntimeError as exc:
@@ -954,6 +1033,24 @@ def recognize_fix_math(
         return
 
     output_path = Path(output_dir) if output_dir else None
+    report_target = None
+    if report_path:
+        report_target = Path(report_path)
+    elif not only_show_error:
+        if output_path:
+            report_target = output_path / "fix-math-errors.json"
+        elif in_place:
+            report_target = Path.cwd() / "fix-math-errors.json"
+
+    retry_targets: dict[Path, set[RetryKey]] | None = None
+    if retry_failed:
+        if report_target is None:
+            raise click.ClickException("--retry-failed requires an error report path")
+        retry_targets = _load_retry_targets(report_target)
+        paths = [path for path in paths if path.resolve() in retry_targets]
+        if not paths:
+            raise click.ClickException("No failed formulas matched the provided inputs")
+
     if output_path and not dry_run and not only_show_error:
         output_path = _ensure_output_dir(output_dir)
         _warn_if_not_empty(output_path)
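The retry keys loaded here are `(absolute_line, field_path, item_index)` triples, and every match site below derives the absolute line the same way: `line_offset + span.line - 1`. A toy check of that convention — the `Span` stand-in is illustrative; the real span type comes from the recognize modules:

```python
# Toy check of the RetryKey matching convention; Span is a stand-in type.
from dataclasses import dataclass

RetryKey = tuple[int, str | None, int | None]

@dataclass
class Span:
    line: int  # 1-based line within the field (or file) text

def absolute_line(line_offset: int, span: Span) -> int:
    # A field starting at file line 100 puts its span line 21 at file line 120.
    return line_offset + span.line - 1

keys: set[RetryKey] = {(120, "papers[3].summary", 3)}
assert (absolute_line(100, Span(line=21)), "papers[3].summary", 3) in keys
```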
@@ -969,15 +1066,6 @@ def recognize_fix_math(
     else:
         output_map = {path: path for path in paths}
 
-    report_target = None
-    if report_path:
-        report_target = Path(report_path)
-    elif not only_show_error:
-        if output_path:
-            report_target = output_path / "fix-math-errors.json"
-        elif in_place:
-            report_target = Path.cwd() / "fix-math-errors.json"
-
     if dry_run and not only_show_error:
         rows = [
             ("Mode", "json" if json_mode else "markdown"),
@@ -988,6 +1076,7 @@ def recognize_fix_math(
             ("Max retries", str(max_retries)),
             ("Workers", str(workers)),
             ("Timeout", f"{timeout:.1f}s"),
+            ("Retry failed", "yes" if retry_failed else "no"),
             ("Only show error", "yes" if only_show_error else "no"),
             ("In place", "yes" if in_place else "no"),
             ("Output dir", _relative_path(output_path) if output_path else "-"),
@@ -1021,12 +1110,24 @@ def recognize_fix_math(
                 value = item.get(field)
                 if not isinstance(value, str):
                     continue
-                spans = extract_math_spans(value, context_chars)
-                if spans:
-                    formula_progress.total += len(spans)
-                    formula_progress.refresh()
                 line_start, cursor = locate_json_field_start(raw_text, value, cursor)
                 field_path = f"papers[{item_index}].{field}"
+                spans = extract_math_spans(value, context_chars)
+                retry_keys = None
+                if retry_targets is not None:
+                    retry_keys = retry_targets.get(path.resolve(), set())
+                    retry_keys = {
+                        key
+                        for key in retry_keys
+                        if key[1] == field_path and key[2] == item_index
+                    }
+                spans = _filter_retry_spans(
+                    spans, line_start, field_path, item_index, retry_keys
+                )
+                if not spans:
+                    continue
+                formula_progress.total += len(spans)
+                formula_progress.refresh()
                 updated, errors = await fix_math_text(
                     value,
                     str(path),
@@ -1044,6 +1145,7 @@ def recognize_fix_math(
                     stats,
                     repair_enabled=not only_show_error,
                     spans=spans,
+                    allowed_keys=retry_keys,
                     progress_cb=lambda: formula_progress.update(1),
                 )
                 if not only_show_error and updated != value:
@@ -1057,6 +1159,12 @@ def recognize_fix_math(
         else:
             content = await asyncio.to_thread(read_text, path)
             spans = extract_math_spans(content, context_chars)
+            retry_keys = None
+            if retry_targets is not None:
+                retry_keys = retry_targets.get(path.resolve(), set())
+            spans = _filter_retry_spans(spans, 1, None, None, retry_keys)
+            if not spans:
+                return stats
             if spans:
                 formula_progress.total += len(spans)
                 formula_progress.refresh()
@@ -1077,6 +1185,7 @@ def recognize_fix_math(
                 stats,
                 repair_enabled=not only_show_error,
                 spans=spans,
+                allowed_keys=retry_keys,
                 progress_cb=lambda: formula_progress.update(1),
             )
             if not only_show_error:
@@ -1121,6 +1230,7 @@ def recognize_fix_math(
             ("Cleaned", str(stats.formulas_cleaned)),
             ("Repaired", str(stats.formulas_repaired)),
             ("Failed", str(stats.formulas_failed)),
+            ("Retry failed", "yes" if retry_failed else "no"),
             ("Only show error", "yes" if only_show_error else "no"),
             ("Report", _relative_path(report_target) if report_target else "-"),
         ]
@@ -1147,6 +1257,7 @@ def recognize_fix_math(
 @click.option("--max-retries", "max_retries", default=3, show_default=True, type=int)
 @click.option("--workers", type=int, default=4, show_default=True, help="Concurrent workers")
 @click.option("--timeout", "timeout", default=120.0, show_default=True, type=float)
+@click.option("--retry-failed", "retry_failed", is_flag=True, help="Retry only failed diagrams")
 @click.option(
     "--only-show-error",
     "only_show_error",
@@ -1169,6 +1280,7 @@ def recognize_fix_mermaid(
     max_retries: int,
     workers: int,
     timeout: float,
+    retry_failed: bool,
     only_show_error: bool,
     report_path: str | None,
     dry_run: bool,
@@ -1188,6 +1300,8 @@ def recognize_fix_mermaid(
         raise click.ClickException("--max-retries must be non-negative")
     if workers <= 0:
         raise click.ClickException("--workers must be positive")
+    if retry_failed and only_show_error:
+        raise click.ClickException("--retry-failed cannot be used with --only-show-error")
     try:
         require_mmdc()
     except RuntimeError as exc:
@@ -1231,6 +1345,24 @@ def recognize_fix_mermaid(
         return
 
     output_path = Path(output_dir) if output_dir else None
+    report_target = None
+    if report_path:
+        report_target = Path(report_path)
+    elif not only_show_error:
+        if output_path:
+            report_target = output_path / "fix-mermaid-errors.json"
+        elif in_place:
+            report_target = Path.cwd() / "fix-mermaid-errors.json"
+
+    retry_targets: dict[Path, set[RetryKey]] | None = None
+    if retry_failed:
+        if report_target is None:
+            raise click.ClickException("--retry-failed requires an error report path")
+        retry_targets = _load_retry_targets(report_target)
+        paths = [path for path in paths if path.resolve() in retry_targets]
+        if not paths:
+            raise click.ClickException("No failed diagrams matched the provided inputs")
+
     if output_path and not dry_run and not only_show_error:
         output_path = _ensure_output_dir(output_dir)
         _warn_if_not_empty(output_path)
@@ -1246,15 +1378,6 @@ def recognize_fix_mermaid(
     else:
         output_map = {path: path for path in paths}
 
-    report_target = None
-    if report_path:
-        report_target = Path(report_path)
-    elif not only_show_error:
-        if output_path:
-            report_target = output_path / "fix-mermaid-errors.json"
-        elif in_place:
-            report_target = Path.cwd() / "fix-mermaid-errors.json"
-
     if dry_run and not only_show_error:
         rows = [
             ("Mode", "json" if json_mode else "markdown"),
|
|
|
1265
1388
|
("Max retries", str(max_retries)),
|
|
1266
1389
|
("Workers", str(workers)),
|
|
1267
1390
|
("Timeout", f"{timeout:.1f}s"),
|
|
1391
|
+
("Retry failed", "yes" if retry_failed else "no"),
|
|
1268
1392
|
("Only show error", "yes" if only_show_error else "no"),
|
|
1269
1393
|
("In place", "yes" if in_place else "no"),
|
|
1270
1394
|
("Output dir", _relative_path(output_path) if output_path else "-"),
|
|
@@ -1273,112 +1397,260 @@ def recognize_fix_mermaid(
         _print_summary("recognize fix-mermaid (dry-run)", rows)
         return
 
-    progress = tqdm(total=len(paths), desc="
-
+    progress = tqdm(total=len(paths), desc="extract", unit="file")
+    field_progress = tqdm(total=0, desc="extract-field", unit="field", disable=not json_mode, leave=False)
+    diagram_progress = tqdm(total=0, desc="repair", unit="diagram")
     error_records: list[dict[str, Any]] = []
+
+    # Performance metrics
+    extract_start_time = time.monotonic()
+    repair_start_time = 0.0
+    extract_duration = 0.0
+    repair_duration = 0.0
 
     async def run() -> MermaidFixStats:
-        semaphore = asyncio.Semaphore(workers)
-        progress_lock = asyncio.Lock()
         stats_total = MermaidFixStats()
 
         async with httpx.AsyncClient() as client:
-
-
+            # Phase 1: Extract all diagrams from all files in parallel (flatten to 1D)
+            progress_lock = asyncio.Lock()
+            field_progress_lock = asyncio.Lock()
+
+            async def extract_from_file(path: Path) -> list[DiagramTask]:
+                tasks: list[DiagramTask] = []
+
                 if json_mode:
-                    raw_text = read_text
+                    raw_text = await asyncio.to_thread(read_text, path)
                     items, payload, template_tag = _load_json_payload(path)
+
+                    logger.info("Extracting from JSON: %s (%d papers)", _relative_path(path), len(items))
+
+                    # Pre-calculate all field positions for parallel extraction
+                    field_locations: list[tuple[int, str, str, str | None, int]] = []
                     cursor = 0
+
                    for item_index, item in enumerate(items):
                         if not isinstance(item, dict):
                             continue
                         template = _resolve_item_template(item, template_tag)
                         fields = _template_markdown_fields(template)
+
                         for field in fields:
                             value = item.get(field)
                             if not isinstance(value, str):
                                 continue
-                            spans = extract_mermaid_spans(value, context_chars)
-                            if spans:
-                                diagram_progress.total += len(spans)
-                                diagram_progress.refresh()
                             line_start, cursor = locate_json_field_start(raw_text, value, cursor)
                             field_path = f"papers[{item_index}].{field}"
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                            field_locations.append((line_start, value, field_path, None, item_index))
+
+                    logger.info("Pre-calculated %d field locations from %s", len(field_locations), _relative_path(path))
+
+                    # Apply retry filter to field locations if needed
+                    if retry_targets is not None:
+                        retry_keys = retry_targets.get(path.resolve(), set())
+                        # Prefer filtering by (field_path, item_index) to avoid expensive validation / mmdc calls.
+                        retry_fields = {
+                            (field_path, item_index)
+                            for _, field_path, item_index in retry_keys
+                            if field_path is not None and item_index is not None
+                        }
+                        if retry_fields:
+                            before = len(field_locations)
+                            field_locations = [
+                                loc for loc in field_locations if (loc[2], loc[4]) in retry_fields
+                            ]
+                            logger.info(
+                                "Retry filter: %d/%d fields match (by field_path)",
+                                len(field_locations),
+                                before,
                             )
-
-
-
-
-
-
-
-
+                        else:
+                            # Fallback: filter by line numbers using fast span extraction (no validation).
+                            filtered_locations: list[tuple[int, str, str, str | None, int]] = []
+                            for line_start, value, field_path, _, item_index in field_locations:
+                                spans = extract_mermaid_spans(value, context_chars)
+                                if any(
+                                    (line_start + span.line - 1, field_path, item_index) in retry_keys
+                                    for span in spans
+                                ):
+                                    filtered_locations.append((line_start, value, field_path, None, item_index))
+                            field_locations = filtered_locations
+                            logger.info("Retry filter: %d fields match (by line)", len(field_locations))
+
+                    # Parallel extraction from all fields
+                    async def extract_from_field(loc: tuple[int, str, str, str | None, int]) -> list[DiagramTask]:
+                        line_start, value, field_path, _, item_index = loc
+                        field_tasks = extract_diagrams_from_text(
+                            value, path, line_start, field_path, item_index, context_chars,
+                            skip_validation=not only_show_error  # Skip validation unless validating only
+                        )
+
+                        # Apply retry filter to individual tasks
+                        if retry_targets is not None:
+                            retry_keys = retry_targets.get(path.resolve(), set())
+                            field_tasks = [
+                                task for task in field_tasks
+                                if (task.file_line_offset + task.span.line - 1, task.field_path, task.item_index) in retry_keys
+                            ]
+
+                        return field_tasks
+
+                    if field_locations:
+                        logger.info("Extracting diagrams from %d fields in parallel...", len(field_locations))
+
+                        async with field_progress_lock:
+                            field_progress.total += len(field_locations)
+                            field_progress.refresh()
+
+                        # Bounded worker pool (avoid scheduling thousands of coroutines at once).
+                        max_field_workers = 50
+                        field_workers = min(max_field_workers, len(field_locations))
+                        field_queue: asyncio.Queue[tuple[int, str, str, str | None, int] | None] = asyncio.Queue()
+                        for loc in field_locations:
+                            field_queue.put_nowait(loc)
+                        for _ in range(field_workers):
+                            field_queue.put_nowait(None)
+
+                        async def field_worker() -> list[DiagramTask]:
+                            out: list[DiagramTask] = []
+                            while True:
+                                loc = await field_queue.get()
+                                if loc is None:
+                                    break
+                                out.extend(await extract_from_field(loc))
+                                async with field_progress_lock:
+                                    field_progress.update(1)
+                            return out
+
+                        worker_results = await asyncio.gather(*[field_worker() for _ in range(field_workers)])
+                        for batch in worker_results:
+                            tasks.extend(batch)
+
+                    logger.info("Extracted %d diagrams from %s", len(tasks), _relative_path(path))
                 else:
                     content = await asyncio.to_thread(read_text, path)
-
-
-
-
-
-                        content,
-
-                        1,
-                        None,
-                        None,
-                        provider,
-                        model_name,
-                        api_key,
-                        timeout,
-                        max_retries,
-                        batch_size,
-                        context_chars,
-                        client,
-                        stats,
-                        repair_enabled=not only_show_error,
-                        spans=spans,
-                        progress_cb=lambda: diagram_progress.update(1),
+
+                    logger.info("Extracting from markdown: %s", _relative_path(path))
+
+                    # Extract diagrams from markdown
+                    file_tasks = extract_diagrams_from_text(
+                        content, path, 1, None, None, context_chars,
+                        skip_validation=not only_show_error  # Skip validation unless validating only
                     )
-
-
+
+                    # Apply retry filter if needed
+                    if retry_targets is not None:
+                        retry_keys = retry_targets.get(path.resolve(), set())
+                        file_tasks = [
+                            task for task in file_tasks
+                            if (task.file_line_offset + task.span.line - 1, task.field_path, task.item_index) in retry_keys
+                        ]
+
+                    tasks.extend(file_tasks)
+                    logger.info("Extracted %d diagrams from %s", len(tasks), _relative_path(path))
+
+                async with progress_lock:
+                    progress.update(1)
+                return tasks
+
+            # Parallel extraction with progress
+            file_task_lists = await asyncio.gather(*[extract_from_file(path) for path in paths])
+            all_tasks = [task for tasks in file_task_lists for task in tasks]
+
+            progress.close()
+            field_progress.close()
+            nonlocal extract_duration, repair_start_time
+            extract_duration = time.monotonic() - extract_start_time
+
+            # Update diagram progress total
+            diagram_progress.total = len(all_tasks)
+            diagram_progress.refresh()
+
+            if not all_tasks:
+                return stats_total
+
+            # Phase 2: Global parallel repair (flatten all batches)
+            repair_start_time = time.monotonic()
+            file_replacements, errors = await repair_all_diagrams_global(
+                all_tasks,
+                batch_size,
+                workers,  # Use workers for global batch concurrency
+                provider,
+                model_name,
+                api_key,
+                timeout,
+                max_retries,
+                client,
+                stats_total,
+                progress_cb=lambda: diagram_progress.update(1) if not only_show_error else None,
+            )
+
+            error_records.extend(errors)
+            diagram_progress.close()
+            nonlocal repair_duration
+            repair_duration = time.monotonic() - repair_start_time
+
+            # Phase 3: Write back to files
+            if not only_show_error:
+                write_progress = tqdm(total=len(paths), desc="write", unit="file")
+
+                for path in paths:
+                    replacements = file_replacements.get(path, [])
+                    output_path = output_map[path]
+
+                    if json_mode:
+                        # For JSON, apply replacements to fields
+                        raw_text = await asyncio.to_thread(read_text, path)
+                        items, payload, template_tag = _load_json_payload(path)
+                        cursor = 0
+
+                        for item_index, item in enumerate(items):
+                            if not isinstance(item, dict):
+                                continue
+                            template = _resolve_item_template(item, template_tag)
+                            fields = _template_markdown_fields(template)
+
+                            for field in fields:
+                                value = item.get(field)
+                                if not isinstance(value, str):
+                                    continue
+                                field_path = f"papers[{item_index}].{field}"
+
+                                # Find replacements for this specific field
+                                field_replacements = [
+                                    (start, end, repl)
+                                    for start, end, repl in replacements
+                                    if any(
+                                        t.field_path == field_path and t.item_index == item_index and t.span.start == start
+                                        for t in all_tasks
+                                        if t.file_path == path
+                                    )
+                                ]
+
+                                if field_replacements:
+                                    updated_value = apply_replacements(value, field_replacements)
+                                    item[field] = updated_value
+
+                        output_data: Any = items if payload is None else {**payload, "papers": items}
+                        serialized = json.dumps(output_data, ensure_ascii=False, indent=2)
+                        await asyncio.to_thread(output_path.write_text, f"{serialized}\n", encoding="utf-8")
+                    else:
+                        # For markdown, apply replacements directly
+                        content = await asyncio.to_thread(read_text, path)
+                        updated = apply_replacements(content, replacements)
                         await asyncio.to_thread(output_path.write_text, updated, encoding="utf-8")
-
-
-
-
-
-            stats = await handle_path(path)
-            stats_total.diagrams_total += stats.diagrams_total
-            stats_total.diagrams_invalid += stats.diagrams_invalid
-            stats_total.diagrams_repaired += stats.diagrams_repaired
-            stats_total.diagrams_failed += stats.diagrams_failed
-            async with progress_lock:
-                progress.update(1)
-
-        await asyncio.gather(*(runner(path) for path in paths))
+
+                    write_progress.update(1)
+
+                write_progress.close()
+
         return stats_total
 
     try:
         stats = asyncio.run(run())
     finally:
         progress.close()
+        field_progress.close()
         diagram_progress.close()
 
     if report_target and error_records:
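Phase 1 above drains field locations through a fixed number of queue workers rather than spawning one coroutine per field. A standalone sketch of that queue-plus-sentinel pattern, using only the standard library (the integers and the doubling stand in for field locations and real extraction work):

```python
# Sketch: bounded asyncio worker pool with one shutdown sentinel per worker.
import asyncio

async def run_pool(items: list[int], max_workers: int = 50) -> list[int]:
    queue: asyncio.Queue[int | None] = asyncio.Queue()
    for item in items:
        queue.put_nowait(item)
    workers = min(max_workers, len(items))
    for _ in range(workers):
        queue.put_nowait(None)  # one sentinel per worker signals shutdown

    async def worker() -> list[int]:
        out: list[int] = []
        while True:
            item = await queue.get()
            if item is None:
                break
            out.append(item * 2)  # stand-in for extract_from_field(loc)
        return out

    batches = await asyncio.gather(*(worker() for _ in range(workers)))
    return [result for batch in batches for result in batch]

print(asyncio.run(run_pool(list(range(10)), max_workers=3)))
```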
@@ -1396,6 +1668,11 @@ def recognize_fix_mermaid(
             ("Invalid", str(stats.diagrams_invalid)),
             ("Repaired", str(stats.diagrams_repaired)),
             ("Failed", str(stats.diagrams_failed)),
+            ("Extract time", _format_duration(extract_duration)),
+            ("Extract avg", f"{extract_duration / stats.diagrams_total:.3f}s/diagram" if stats.diagrams_total > 0 else "-"),
+            ("Repair time", _format_duration(repair_duration)),
+            ("Repair avg", f"{repair_duration / stats.diagrams_invalid:.3f}s/diagram" if stats.diagrams_invalid > 0 else "-"),
+            ("Retry failed", "yes" if retry_failed else "no"),
             ("Only show error", "yes" if only_show_error else "no"),
             ("Report", _relative_path(report_target) if report_target else "-"),
         ]
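Phase 3 funnels every accepted repair through `apply_replacements`, which receives `(start, end, replacement)` triples computed against the original text. A minimal sketch of the contract such a helper has to honor — an illustration, not the implementation in deepresearch_flow/recognize/mermaid.py:

```python
# Sketch of an apply_replacements-style helper: apply (start, end, repl)
# triples computed against the ORIGINAL text without corrupting offsets.
def apply_replacements(text: str, replacements: list[tuple[int, int, str]]) -> str:
    # Applying right-to-left keeps earlier offsets valid as the text changes length.
    for start, end, repl in sorted(replacements, reverse=True):
        text = text[:start] + repl + text[end:]
    return text

assert apply_replacements("a BAD c BAD", [(2, 5, "ok"), (8, 11, "ok")]) == "a ok c ok"
```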