deepresearch-flow 0.5.0-py3-none-any.whl → 0.6.0-py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (58)
  1. deepresearch_flow/paper/cli.py +63 -0
  2. deepresearch_flow/paper/config.py +87 -12
  3. deepresearch_flow/paper/db.py +1041 -34
  4. deepresearch_flow/paper/db_ops.py +145 -26
  5. deepresearch_flow/paper/extract.py +1546 -152
  6. deepresearch_flow/paper/prompt_templates/deep_read_phi_system.j2 +8 -0
  7. deepresearch_flow/paper/prompt_templates/deep_read_phi_user.j2 +396 -0
  8. deepresearch_flow/paper/prompt_templates/deep_read_system.j2 +2 -0
  9. deepresearch_flow/paper/prompt_templates/deep_read_user.j2 +272 -40
  10. deepresearch_flow/paper/prompt_templates/eight_questions_phi_system.j2 +7 -0
  11. deepresearch_flow/paper/prompt_templates/eight_questions_phi_user.j2 +135 -0
  12. deepresearch_flow/paper/prompt_templates/eight_questions_system.j2 +2 -0
  13. deepresearch_flow/paper/prompt_templates/eight_questions_user.j2 +4 -0
  14. deepresearch_flow/paper/prompt_templates/simple_phi_system.j2 +8 -0
  15. deepresearch_flow/paper/prompt_templates/simple_phi_user.j2 +31 -0
  16. deepresearch_flow/paper/prompt_templates/simple_system.j2 +2 -0
  17. deepresearch_flow/paper/prompt_templates/simple_user.j2 +2 -0
  18. deepresearch_flow/paper/providers/azure_openai.py +45 -3
  19. deepresearch_flow/paper/providers/openai_compatible.py +45 -3
  20. deepresearch_flow/paper/schemas/deep_read_phi_schema.json +31 -0
  21. deepresearch_flow/paper/schemas/deep_read_schema.json +1 -0
  22. deepresearch_flow/paper/schemas/default_paper_schema.json +6 -0
  23. deepresearch_flow/paper/schemas/eight_questions_schema.json +1 -0
  24. deepresearch_flow/paper/snapshot/__init__.py +4 -0
  25. deepresearch_flow/paper/snapshot/api.py +941 -0
  26. deepresearch_flow/paper/snapshot/builder.py +965 -0
  27. deepresearch_flow/paper/snapshot/identity.py +239 -0
  28. deepresearch_flow/paper/snapshot/schema.py +245 -0
  29. deepresearch_flow/paper/snapshot/tests/__init__.py +2 -0
  30. deepresearch_flow/paper/snapshot/tests/test_identity.py +123 -0
  31. deepresearch_flow/paper/snapshot/text.py +154 -0
  32. deepresearch_flow/paper/template_registry.py +40 -0
  33. deepresearch_flow/paper/templates/deep_read.md.j2 +4 -0
  34. deepresearch_flow/paper/templates/deep_read_phi.md.j2 +44 -0
  35. deepresearch_flow/paper/templates/default_paper.md.j2 +4 -0
  36. deepresearch_flow/paper/templates/eight_questions.md.j2 +4 -0
  37. deepresearch_flow/paper/web/app.py +10 -3
  38. deepresearch_flow/paper/web/markdown.py +174 -8
  39. deepresearch_flow/paper/web/static/css/main.css +8 -1
  40. deepresearch_flow/paper/web/static/js/detail.js +46 -12
  41. deepresearch_flow/paper/web/templates/detail.html +9 -0
  42. deepresearch_flow/paper/web/text.py +8 -4
  43. deepresearch_flow/recognize/cli.py +380 -103
  44. deepresearch_flow/recognize/markdown.py +31 -7
  45. deepresearch_flow/recognize/math.py +47 -12
  46. deepresearch_flow/recognize/mermaid.py +320 -10
  47. deepresearch_flow/recognize/organize.py +35 -16
  48. deepresearch_flow/translator/cli.py +71 -20
  49. deepresearch_flow/translator/engine.py +220 -81
  50. deepresearch_flow/translator/fixers.py +15 -0
  51. deepresearch_flow/translator/prompts.py +19 -2
  52. deepresearch_flow/translator/protector.py +15 -3
  53. {deepresearch_flow-0.5.0.dist-info → deepresearch_flow-0.6.0.dist-info}/METADATA +407 -33
  54. {deepresearch_flow-0.5.0.dist-info → deepresearch_flow-0.6.0.dist-info}/RECORD +58 -42
  55. {deepresearch_flow-0.5.0.dist-info → deepresearch_flow-0.6.0.dist-info}/WHEEL +1 -1
  56. {deepresearch_flow-0.5.0.dist-info → deepresearch_flow-0.6.0.dist-info}/entry_points.txt +0 -0
  57. {deepresearch_flow-0.5.0.dist-info → deepresearch_flow-0.6.0.dist-info}/licenses/LICENSE +0 -0
  58. {deepresearch_flow-0.5.0.dist-info → deepresearch_flow-0.6.0.dist-info}/top_level.txt +0 -0
@@ -42,6 +42,10 @@ from deepresearch_flow.recognize.mermaid import (
     extract_mermaid_spans,
     fix_mermaid_text,
     require_mmdc,
+    extract_diagrams_from_text,
+    repair_all_diagrams_global,
+    DiagramTask,
+    apply_replacements,
 )
 from deepresearch_flow.recognize.organize import (
     discover_mineru_dirs,
@@ -73,7 +77,8 @@ def _relative_path(path: Path) -> str:
 
 def _warn_if_not_empty(output_dir: Path) -> None:
     if output_dir.exists() and any(output_dir.iterdir()):
-        logger.warning("Output directory not empty: %s", output_dir)
+        item_count = sum(1 for _ in output_dir.iterdir())
+        logger.warning("Output directory not empty: %s (items=%d)", output_dir, item_count)
 
 
 def _print_summary(title: str, rows: list[tuple[str, str]]) -> None:
@@ -114,6 +119,60 @@ def _map_output_files(
     return mapping
 
 
+RetryKey = tuple[int, str | None, int | None]
+
+
+def _load_retry_targets(report_path: Path) -> dict[Path, set[RetryKey]]:
+    if not report_path.exists():
+        raise click.ClickException(f"Retry report not found: {report_path}")
+    try:
+        payload = json.loads(read_text(report_path))
+    except json.JSONDecodeError as exc:
+        raise click.ClickException(f"Retry report is not valid JSON: {exc}") from exc
+    if not isinstance(payload, list) or not payload:
+        raise click.ClickException(f"Retry report is empty: {report_path}")
+    targets: dict[Path, set[RetryKey]] = {}
+    for entry in payload:
+        if not isinstance(entry, dict):
+            continue
+        path_raw = entry.get("path")
+        line_raw = entry.get("line")
+        if not path_raw or line_raw is None:
+            continue
+        try:
+            line_no = int(line_raw)
+        except (TypeError, ValueError):
+            continue
+        field_path = entry.get("field_path")
+        if not isinstance(field_path, str):
+            field_path = None
+        item_index = entry.get("item_index")
+        if not isinstance(item_index, int):
+            item_index = None
+        key = (line_no, field_path, item_index)
+        targets.setdefault(Path(path_raw).resolve(), set()).add(key)
+    if not targets:
+        raise click.ClickException(f"Retry report has no valid entries: {report_path}")
+    return targets
+
+
+def _filter_retry_spans(
+    spans: list[Any],
+    line_offset: int,
+    field_path: str | None,
+    item_index: int | None,
+    retry_keys: set[RetryKey] | None,
+) -> list[Any]:
+    if not retry_keys:
+        return spans
+    filtered: list[Any] = []
+    for span in spans:
+        line_no = line_offset + span.line - 1
+        if (line_no, field_path, item_index) in retry_keys:
+            filtered.append(span)
+    return filtered
+
+
 def _aggregate_image_counts(paths: Iterable[Path]) -> dict[str, int]:
     totals = {"total": 0, "data": 0, "http": 0, "local": 0}
     for path in paths:
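
Note: `_load_retry_targets` keys every report entry by `(line, field_path, item_index)` under the entry's resolved `path`. A minimal sketch of the report shape it accepts, inferred from the parser above (file names and values here are illustrative, not taken from this release):

    # Illustrative retry-report entries matching what _load_retry_targets parses.
    # "path" and "line" are required; "field_path" and "item_index" are optional
    # and fall back to None when absent or of the wrong type.
    import json
    from pathlib import Path

    report = [
        {"path": "out/papers.json", "line": 1042,
         "field_path": "papers[3].summary", "item_index": 3},
        {"path": "out/notes.md", "line": 87},  # markdown entries carry no field metadata
    ]
    Path("fix-math-errors.json").write_text(
        json.dumps(report, indent=2), encoding="utf-8"
    )
    # Derived keys: (1042, "papers[3].summary", 3) and (87, None, None)
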
@@ -194,6 +253,8 @@ async def _fix_json_items(
     default_template: str | None,
     fix_level: str,
     format_enabled: bool,
+    progress: tqdm | None = None,
+    progress_lock: asyncio.Lock | None = None,
 ) -> tuple[int, int, int, int]:
     items_total = 0
     items_updated = 0
@@ -218,6 +279,9 @@ async def _fix_json_items(
                 item_updated = True
         if item_updated:
            items_updated += 1
+        if progress and progress_lock:
+            async with progress_lock:
+                progress.update(1)
     return items_total, items_updated, fields_total, fields_updated
 
 
@@ -350,7 +414,12 @@ async def _run_fix_json(
     async def handler(path: Path) -> tuple[int, int, int, int, int]:
         items, payload, template_tag = _load_json_payload(path)
         items_total, items_updated, fields_total, fields_updated = await _fix_json_items(
-            items, template_tag, fix_level, format_enabled
+            items,
+            template_tag,
+            fix_level,
+            format_enabled,
+            progress,
+            progress_lock,
         )
         output_data: Any
         if payload is None:
@@ -367,9 +436,6 @@ async def _run_fix_json(
         async with semaphore:
             result = await handler(path)
             results.append(result)
-            if progress and progress_lock:
-                async with progress_lock:
-                    progress.update(1)
 
     await asyncio.gather(*(runner(path) for path in paths))
     return results
@@ -787,7 +853,16 @@ def recognize_fix(
         _print_summary("recognize fix (dry-run)", rows)
         return
 
-    progress = tqdm(total=len(paths), desc="fix", unit="file")
+    progress_total = len(paths)
+    progress_unit = "file"
+    if json_mode:
+        json_items_total = 0
+        for path in paths:
+            items, _, _ = _load_json_payload(path)
+            json_items_total += sum(1 for item in items if isinstance(item, dict))
+        progress_total = json_items_total
+        progress_unit = "item"
+    progress = tqdm(total=progress_total, desc="fix", unit=progress_unit)
     try:
         if json_mode:
             results = asyncio.run(
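
Note: together with the `_fix_json_items` changes above, the fix progress bar now advances once per JSON item instead of once per file, with the total pre-counted before the run starts. A condensed, self-contained sketch of the pattern (names here are illustrative, not the package's):

    # Pre-count work items so tqdm's total reflects items rather than files;
    # an asyncio.Lock serializes updates from concurrent workers.
    import asyncio
    from tqdm import tqdm

    async def fix_items(items: list[dict], progress: tqdm, lock: asyncio.Lock) -> None:
        for _ in items:
            await asyncio.sleep(0)  # stand-in for the real per-item fix work
            async with lock:
                progress.update(1)

    async def main(batches: list[list[dict]]) -> None:
        total = sum(len(batch) for batch in batches)
        progress = tqdm(total=total, desc="fix", unit="item")
        lock = asyncio.Lock()
        await asyncio.gather(*(fix_items(batch, progress, lock) for batch in batches))
        progress.close()

    asyncio.run(main([[{}] * 3, [{}] * 2]))
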
@@ -870,6 +945,7 @@ def recognize_fix(
 @click.option("--max-retries", "max_retries", default=3, show_default=True, type=int)
 @click.option("--workers", type=int, default=4, show_default=True, help="Concurrent workers")
 @click.option("--timeout", "timeout", default=120.0, show_default=True, type=float)
+@click.option("--retry-failed", "retry_failed", is_flag=True, help="Retry only failed formulas")
 @click.option(
     "--only-show-error",
     "only_show_error",
@@ -892,6 +968,7 @@ def recognize_fix_math(
     max_retries: int,
     workers: int,
     timeout: float,
+    retry_failed: bool,
     only_show_error: bool,
     report_path: str | None,
     dry_run: bool,
@@ -911,6 +988,8 @@ def recognize_fix_math(
         raise click.ClickException("--max-retries must be non-negative")
     if workers <= 0:
         raise click.ClickException("--workers must be positive")
+    if retry_failed and only_show_error:
+        raise click.ClickException("--retry-failed cannot be used with --only-show-error")
     try:
         require_pylatexenc()
     except RuntimeError as exc:
@@ -954,6 +1033,24 @@ def recognize_fix_math(
         return
 
     output_path = Path(output_dir) if output_dir else None
+    report_target = None
+    if report_path:
+        report_target = Path(report_path)
+    elif not only_show_error:
+        if output_path:
+            report_target = output_path / "fix-math-errors.json"
+        elif in_place:
+            report_target = Path.cwd() / "fix-math-errors.json"
+
+    retry_targets: dict[Path, set[RetryKey]] | None = None
+    if retry_failed:
+        if report_target is None:
+            raise click.ClickException("--retry-failed requires an error report path")
+        retry_targets = _load_retry_targets(report_target)
+        paths = [path for path in paths if path.resolve() in retry_targets]
+        if not paths:
+            raise click.ClickException("No failed formulas matched the provided inputs")
+
     if output_path and not dry_run and not only_show_error:
         output_path = _ensure_output_dir(output_dir)
         _warn_if_not_empty(output_path)
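
Note: the report-path resolution moves ahead of the new retry gate because `--retry-failed` reads the report written by a previous run. The precedence, condensed into a hypothetical helper (not part of the package):

    # Hypothetical condensation of the resolution order used above: an explicit
    # report path wins, then <output_dir>/fix-math-errors.json, then
    # ./fix-math-errors.json when editing in place; no report is produced in
    # --only-show-error mode unless one was named explicitly.
    from pathlib import Path

    def resolve_report_target(
        report_path: str | None,
        output_path: Path | None,
        in_place: bool,
        only_show_error: bool,
    ) -> Path | None:
        if report_path:
            return Path(report_path)
        if only_show_error:
            return None
        if output_path:
            return output_path / "fix-math-errors.json"
        if in_place:
            return Path.cwd() / "fix-math-errors.json"
        return None
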
@@ -969,15 +1066,6 @@ def recognize_fix_math(
     else:
         output_map = {path: path for path in paths}
 
-    report_target = None
-    if report_path:
-        report_target = Path(report_path)
-    elif not only_show_error:
-        if output_path:
-            report_target = output_path / "fix-math-errors.json"
-        elif in_place:
-            report_target = Path.cwd() / "fix-math-errors.json"
-
     if dry_run and not only_show_error:
         rows = [
             ("Mode", "json" if json_mode else "markdown"),
@@ -988,6 +1076,7 @@ def recognize_fix_math(
             ("Max retries", str(max_retries)),
             ("Workers", str(workers)),
             ("Timeout", f"{timeout:.1f}s"),
+            ("Retry failed", "yes" if retry_failed else "no"),
             ("Only show error", "yes" if only_show_error else "no"),
             ("In place", "yes" if in_place else "no"),
             ("Output dir", _relative_path(output_path) if output_path else "-"),
@@ -1021,12 +1110,24 @@ def recognize_fix_math(
                             value = item.get(field)
                             if not isinstance(value, str):
                                 continue
-                            spans = extract_math_spans(value, context_chars)
-                            if spans:
-                                formula_progress.total += len(spans)
-                                formula_progress.refresh()
                             line_start, cursor = locate_json_field_start(raw_text, value, cursor)
                             field_path = f"papers[{item_index}].{field}"
+                            spans = extract_math_spans(value, context_chars)
+                            retry_keys = None
+                            if retry_targets is not None:
+                                retry_keys = retry_targets.get(path.resolve(), set())
+                                retry_keys = {
+                                    key
+                                    for key in retry_keys
+                                    if key[1] == field_path and key[2] == item_index
+                                }
+                            spans = _filter_retry_spans(
+                                spans, line_start, field_path, item_index, retry_keys
+                            )
+                            if not spans:
+                                continue
+                            formula_progress.total += len(spans)
+                            formula_progress.refresh()
                             updated, errors = await fix_math_text(
                                 value,
                                 str(path),
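
Note: span positions are relative to the JSON string field they live in, so `_filter_retry_spans` maps them back to absolute file lines before matching the report keys. The arithmetic, as a worked example (values illustrative):

    # A field value starting at file line 1040 (line_start) with a span on
    # line 3 of that value sits at absolute line 1040 + 3 - 1 = 1042, so the
    # key below must be in the retry set for the span to be kept.
    line_start, span_line = 1040, 3
    key = (line_start + span_line - 1, "papers[3].summary", 3)
    assert key == (1042, "papers[3].summary", 3)
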
@@ -1044,6 +1145,7 @@ def recognize_fix_math(
                                 stats,
                                 repair_enabled=not only_show_error,
                                 spans=spans,
+                                allowed_keys=retry_keys,
                                 progress_cb=lambda: formula_progress.update(1),
                             )
                             if not only_show_error and updated != value:
@@ -1057,6 +1159,12 @@ def recognize_fix_math(
                 else:
                     content = await asyncio.to_thread(read_text, path)
                     spans = extract_math_spans(content, context_chars)
+                    retry_keys = None
+                    if retry_targets is not None:
+                        retry_keys = retry_targets.get(path.resolve(), set())
+                    spans = _filter_retry_spans(spans, 1, None, None, retry_keys)
+                    if not spans:
+                        return stats
                     if spans:
                         formula_progress.total += len(spans)
                         formula_progress.refresh()
@@ -1077,6 +1185,7 @@ def recognize_fix_math(
                         stats,
                         repair_enabled=not only_show_error,
                         spans=spans,
+                        allowed_keys=retry_keys,
                         progress_cb=lambda: formula_progress.update(1),
                     )
                     if not only_show_error:
@@ -1121,6 +1230,7 @@ def recognize_fix_math(
         ("Cleaned", str(stats.formulas_cleaned)),
         ("Repaired", str(stats.formulas_repaired)),
         ("Failed", str(stats.formulas_failed)),
+        ("Retry failed", "yes" if retry_failed else "no"),
         ("Only show error", "yes" if only_show_error else "no"),
         ("Report", _relative_path(report_target) if report_target else "-"),
     ]
@@ -1147,6 +1257,7 @@ def recognize_fix_math(
 @click.option("--max-retries", "max_retries", default=3, show_default=True, type=int)
 @click.option("--workers", type=int, default=4, show_default=True, help="Concurrent workers")
 @click.option("--timeout", "timeout", default=120.0, show_default=True, type=float)
+@click.option("--retry-failed", "retry_failed", is_flag=True, help="Retry only failed diagrams")
 @click.option(
     "--only-show-error",
     "only_show_error",
@@ -1169,6 +1280,7 @@ def recognize_fix_mermaid(
     max_retries: int,
     workers: int,
     timeout: float,
+    retry_failed: bool,
     only_show_error: bool,
     report_path: str | None,
     dry_run: bool,
@@ -1188,6 +1300,8 @@ def recognize_fix_mermaid(
         raise click.ClickException("--max-retries must be non-negative")
     if workers <= 0:
         raise click.ClickException("--workers must be positive")
+    if retry_failed and only_show_error:
+        raise click.ClickException("--retry-failed cannot be used with --only-show-error")
     try:
         require_mmdc()
     except RuntimeError as exc:
@@ -1231,6 +1345,24 @@ def recognize_fix_mermaid(
         return
 
     output_path = Path(output_dir) if output_dir else None
+    report_target = None
+    if report_path:
+        report_target = Path(report_path)
+    elif not only_show_error:
+        if output_path:
+            report_target = output_path / "fix-mermaid-errors.json"
+        elif in_place:
+            report_target = Path.cwd() / "fix-mermaid-errors.json"
+
+    retry_targets: dict[Path, set[RetryKey]] | None = None
+    if retry_failed:
+        if report_target is None:
+            raise click.ClickException("--retry-failed requires an error report path")
+        retry_targets = _load_retry_targets(report_target)
+        paths = [path for path in paths if path.resolve() in retry_targets]
+        if not paths:
+            raise click.ClickException("No failed diagrams matched the provided inputs")
+
     if output_path and not dry_run and not only_show_error:
         output_path = _ensure_output_dir(output_dir)
         _warn_if_not_empty(output_path)
@@ -1246,15 +1378,6 @@ def recognize_fix_mermaid(
     else:
         output_map = {path: path for path in paths}
 
-    report_target = None
-    if report_path:
-        report_target = Path(report_path)
-    elif not only_show_error:
-        if output_path:
-            report_target = output_path / "fix-mermaid-errors.json"
-        elif in_place:
-            report_target = Path.cwd() / "fix-mermaid-errors.json"
-
     if dry_run and not only_show_error:
         rows = [
             ("Mode", "json" if json_mode else "markdown"),
@@ -1265,6 +1388,7 @@ def recognize_fix_mermaid(
             ("Max retries", str(max_retries)),
             ("Workers", str(workers)),
             ("Timeout", f"{timeout:.1f}s"),
+            ("Retry failed", "yes" if retry_failed else "no"),
             ("Only show error", "yes" if only_show_error else "no"),
             ("In place", "yes" if in_place else "no"),
             ("Output dir", _relative_path(output_path) if output_path else "-"),
@@ -1273,112 +1397,260 @@ def recognize_fix_mermaid(
         _print_summary("recognize fix-mermaid (dry-run)", rows)
         return
 
-    progress = tqdm(total=len(paths), desc="fix-mermaid", unit="file")
-    diagram_progress = tqdm(total=0, desc="diagrams", unit="diagram")
+    progress = tqdm(total=len(paths), desc="extract", unit="file")
+    field_progress = tqdm(total=0, desc="extract-field", unit="field", disable=not json_mode, leave=False)
+    diagram_progress = tqdm(total=0, desc="repair", unit="diagram")
     error_records: list[dict[str, Any]] = []
+
+    # Performance metrics
+    extract_start_time = time.monotonic()
+    repair_start_time = 0.0
+    extract_duration = 0.0
+    repair_duration = 0.0
 
     async def run() -> MermaidFixStats:
-        semaphore = asyncio.Semaphore(workers)
-        progress_lock = asyncio.Lock()
         stats_total = MermaidFixStats()
 
         async with httpx.AsyncClient() as client:
-            async def handle_path(path: Path) -> MermaidFixStats:
-                stats = MermaidFixStats()
+            # Phase 1: Extract all diagrams from all files in parallel (flatten to 1D)
+            progress_lock = asyncio.Lock()
+            field_progress_lock = asyncio.Lock()
+
+            async def extract_from_file(path: Path) -> list[DiagramTask]:
+                tasks: list[DiagramTask] = []
+
                 if json_mode:
-                    raw_text = read_text(path)
+                    raw_text = await asyncio.to_thread(read_text, path)
                     items, payload, template_tag = _load_json_payload(path)
+
+                    logger.info("Extracting from JSON: %s (%d papers)", _relative_path(path), len(items))
+
+                    # Pre-calculate all field positions for parallel extraction
+                    field_locations: list[tuple[int, str, str, str | None, int]] = []
                     cursor = 0
+
                     for item_index, item in enumerate(items):
                         if not isinstance(item, dict):
                             continue
                         template = _resolve_item_template(item, template_tag)
                         fields = _template_markdown_fields(template)
+
                         for field in fields:
                             value = item.get(field)
                             if not isinstance(value, str):
                                 continue
-                            spans = extract_mermaid_spans(value, context_chars)
-                            if spans:
-                                diagram_progress.total += len(spans)
-                                diagram_progress.refresh()
                             line_start, cursor = locate_json_field_start(raw_text, value, cursor)
                             field_path = f"papers[{item_index}].{field}"
-                            updated, errors = await fix_mermaid_text(
-                                value,
-                                str(path),
-                                line_start,
-                                field_path,
-                                item_index,
-                                provider,
-                                model_name,
-                                api_key,
-                                timeout,
-                                max_retries,
-                                batch_size,
-                                context_chars,
-                                client,
-                                stats,
-                                repair_enabled=not only_show_error,
-                                spans=spans,
-                                progress_cb=lambda: diagram_progress.update(1),
+                            field_locations.append((line_start, value, field_path, None, item_index))
+
+                    logger.info("Pre-calculated %d field locations from %s", len(field_locations), _relative_path(path))
+
+                    # Apply retry filter to field locations if needed
+                    if retry_targets is not None:
+                        retry_keys = retry_targets.get(path.resolve(), set())
+                        # Prefer filtering by (field_path, item_index) to avoid expensive validation / mmdc calls.
+                        retry_fields = {
+                            (field_path, item_index)
+                            for _, field_path, item_index in retry_keys
+                            if field_path is not None and item_index is not None
+                        }
+                        if retry_fields:
+                            before = len(field_locations)
+                            field_locations = [
+                                loc for loc in field_locations if (loc[2], loc[4]) in retry_fields
+                            ]
+                            logger.info(
+                                "Retry filter: %d/%d fields match (by field_path)",
+                                len(field_locations),
+                                before,
                             )
-                            if not only_show_error and updated != value:
-                                item[field] = updated
-                            error_records.extend(errors)
-                    if not only_show_error:
-                        output_data: Any = items if payload is None else {**payload, "papers": items}
-                        output_path = output_map[path]
-                        serialized = json.dumps(output_data, ensure_ascii=False, indent=2)
-                        await asyncio.to_thread(output_path.write_text, f"{serialized}\n", encoding="utf-8")
+                        else:
+                            # Fallback: filter by line numbers using fast span extraction (no validation).
+                            filtered_locations: list[tuple[int, str, str, str | None, int]] = []
+                            for line_start, value, field_path, _, item_index in field_locations:
+                                spans = extract_mermaid_spans(value, context_chars)
+                                if any(
+                                    (line_start + span.line - 1, field_path, item_index) in retry_keys
+                                    for span in spans
+                                ):
+                                    filtered_locations.append((line_start, value, field_path, None, item_index))
+                            field_locations = filtered_locations
+                            logger.info("Retry filter: %d fields match (by line)", len(field_locations))
+
+                    # Parallel extraction from all fields
+                    async def extract_from_field(loc: tuple[int, str, str, str | None, int]) -> list[DiagramTask]:
+                        line_start, value, field_path, _, item_index = loc
+                        field_tasks = extract_diagrams_from_text(
+                            value, path, line_start, field_path, item_index, context_chars,
+                            skip_validation=not only_show_error  # Skip validation unless validating only
+                        )
+
+                        # Apply retry filter to individual tasks
+                        if retry_targets is not None:
+                            retry_keys = retry_targets.get(path.resolve(), set())
+                            field_tasks = [
+                                task for task in field_tasks
+                                if (task.file_line_offset + task.span.line - 1, task.field_path, task.item_index) in retry_keys
+                            ]
+
+                        return field_tasks
+
+                    if field_locations:
+                        logger.info("Extracting diagrams from %d fields in parallel...", len(field_locations))
+
+                        async with field_progress_lock:
+                            field_progress.total += len(field_locations)
+                            field_progress.refresh()
+
+                        # Bounded worker pool (avoid scheduling thousands of coroutines at once).
+                        max_field_workers = 50
+                        field_workers = min(max_field_workers, len(field_locations))
+                        field_queue: asyncio.Queue[tuple[int, str, str, str | None, int] | None] = asyncio.Queue()
+                        for loc in field_locations:
+                            field_queue.put_nowait(loc)
+                        for _ in range(field_workers):
+                            field_queue.put_nowait(None)
+
+                        async def field_worker() -> list[DiagramTask]:
+                            out: list[DiagramTask] = []
+                            while True:
+                                loc = await field_queue.get()
+                                if loc is None:
+                                    break
+                                out.extend(await extract_from_field(loc))
+                                async with field_progress_lock:
+                                    field_progress.update(1)
+                            return out
+
+                        worker_results = await asyncio.gather(*[field_worker() for _ in range(field_workers)])
+                        for batch in worker_results:
+                            tasks.extend(batch)
+
+                    logger.info("Extracted %d diagrams from %s", len(tasks), _relative_path(path))
                 else:
                     content = await asyncio.to_thread(read_text, path)
-                    spans = extract_mermaid_spans(content, context_chars)
-                    if spans:
-                        diagram_progress.total += len(spans)
-                        diagram_progress.refresh()
-                    updated, errors = await fix_mermaid_text(
-                        content,
-                        str(path),
-                        1,
-                        None,
-                        None,
-                        provider,
-                        model_name,
-                        api_key,
-                        timeout,
-                        max_retries,
-                        batch_size,
-                        context_chars,
-                        client,
-                        stats,
-                        repair_enabled=not only_show_error,
-                        spans=spans,
-                        progress_cb=lambda: diagram_progress.update(1),
+
+                    logger.info("Extracting from markdown: %s", _relative_path(path))
+
+                    # Extract diagrams from markdown
+                    file_tasks = extract_diagrams_from_text(
+                        content, path, 1, None, None, context_chars,
+                        skip_validation=not only_show_error  # Skip validation unless validating only
                     )
-                    if not only_show_error:
-                        output_path = output_map[path]
+
+                    # Apply retry filter if needed
+                    if retry_targets is not None:
+                        retry_keys = retry_targets.get(path.resolve(), set())
+                        file_tasks = [
+                            task for task in file_tasks
+                            if (task.file_line_offset + task.span.line - 1, task.field_path, task.item_index) in retry_keys
+                        ]
+
+                    tasks.extend(file_tasks)
+                    logger.info("Extracted %d diagrams from %s", len(tasks), _relative_path(path))
+
+                async with progress_lock:
+                    progress.update(1)
+                return tasks
+
+            # Parallel extraction with progress
+            file_task_lists = await asyncio.gather(*[extract_from_file(path) for path in paths])
+            all_tasks = [task for tasks in file_task_lists for task in tasks]
+
+            progress.close()
+            field_progress.close()
+            nonlocal extract_duration, repair_start_time
+            extract_duration = time.monotonic() - extract_start_time
+
+            # Update diagram progress total
+            diagram_progress.total = len(all_tasks)
+            diagram_progress.refresh()
+
+            if not all_tasks:
+                return stats_total
+
+            # Phase 2: Global parallel repair (flatten all batches)
+            repair_start_time = time.monotonic()
+            file_replacements, errors = await repair_all_diagrams_global(
+                all_tasks,
+                batch_size,
+                workers,  # Use workers for global batch concurrency
+                provider,
+                model_name,
+                api_key,
+                timeout,
+                max_retries,
+                client,
+                stats_total,
+                progress_cb=lambda: diagram_progress.update(1) if not only_show_error else None,
+            )
+
+            error_records.extend(errors)
+            diagram_progress.close()
+            nonlocal repair_duration
+            repair_duration = time.monotonic() - repair_start_time
+
+            # Phase 3: Write back to files
+            if not only_show_error:
+                write_progress = tqdm(total=len(paths), desc="write", unit="file")
+
+                for path in paths:
+                    replacements = file_replacements.get(path, [])
+                    output_path = output_map[path]
+
+                    if json_mode:
+                        # For JSON, apply replacements to fields
+                        raw_text = await asyncio.to_thread(read_text, path)
+                        items, payload, template_tag = _load_json_payload(path)
+                        cursor = 0
+
+                        for item_index, item in enumerate(items):
+                            if not isinstance(item, dict):
+                                continue
+                            template = _resolve_item_template(item, template_tag)
+                            fields = _template_markdown_fields(template)
+
+                            for field in fields:
+                                value = item.get(field)
+                                if not isinstance(value, str):
+                                    continue
+                                field_path = f"papers[{item_index}].{field}"
+
+                                # Find replacements for this specific field
+                                field_replacements = [
+                                    (start, end, repl)
+                                    for start, end, repl in replacements
+                                    if any(
+                                        t.field_path == field_path and t.item_index == item_index and t.span.start == start
+                                        for t in all_tasks
+                                        if t.file_path == path
+                                    )
+                                ]
+
+                                if field_replacements:
+                                    updated_value = apply_replacements(value, field_replacements)
+                                    item[field] = updated_value
+
+                        output_data: Any = items if payload is None else {**payload, "papers": items}
+                        serialized = json.dumps(output_data, ensure_ascii=False, indent=2)
+                        await asyncio.to_thread(output_path.write_text, f"{serialized}\n", encoding="utf-8")
+                    else:
+                        # For markdown, apply replacements directly
+                        content = await asyncio.to_thread(read_text, path)
+                        updated = apply_replacements(content, replacements)
                         await asyncio.to_thread(output_path.write_text, updated, encoding="utf-8")
-                    error_records.extend(errors)
-                return stats
-
-            async def runner(path: Path) -> None:
-                async with semaphore:
-                    stats = await handle_path(path)
-                    stats_total.diagrams_total += stats.diagrams_total
-                    stats_total.diagrams_invalid += stats.diagrams_invalid
-                    stats_total.diagrams_repaired += stats.diagrams_repaired
-                    stats_total.diagrams_failed += stats.diagrams_failed
-                    async with progress_lock:
-                        progress.update(1)
-
-            await asyncio.gather(*(runner(path) for path in paths))
+
+                    write_progress.update(1)
+
+                write_progress.close()
+
         return stats_total
 
     try:
         stats = asyncio.run(run())
     finally:
         progress.close()
+        field_progress.close()
         diagram_progress.close()
 
     if report_target and error_records:
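
Note: Phase 1 bounds extraction concurrency with a sentinel-terminated queue rather than spawning one coroutine per field. The same pattern in isolation, as a minimal sketch with generic payloads (nothing here is package API):

    # N workers drain a shared queue; one None sentinel per worker shuts the
    # pool down cleanly once the real items are exhausted.
    import asyncio

    async def process(item: int) -> int:
        await asyncio.sleep(0)  # stand-in for extract_from_field
        return item * 2

    async def run_pool(items: list[int], max_workers: int = 50) -> list[int]:
        workers = max(1, min(max_workers, len(items)))
        queue: asyncio.Queue[int | None] = asyncio.Queue()
        for item in items:
            queue.put_nowait(item)
        for _ in range(workers):
            queue.put_nowait(None)

        async def worker() -> list[int]:
            out: list[int] = []
            while True:
                item = await queue.get()
                if item is None:
                    break
                out.append(await process(item))
            return out

        batches = await asyncio.gather(*(worker() for _ in range(workers)))
        return [x for batch in batches for x in batch]

    print(asyncio.run(run_pool(list(range(10)), max_workers=4)))
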
@@ -1396,6 +1668,11 @@ def recognize_fix_mermaid(
         ("Invalid", str(stats.diagrams_invalid)),
         ("Repaired", str(stats.diagrams_repaired)),
         ("Failed", str(stats.diagrams_failed)),
+        ("Extract time", _format_duration(extract_duration)),
+        ("Extract avg", f"{extract_duration / stats.diagrams_total:.3f}s/diagram" if stats.diagrams_total > 0 else "-"),
+        ("Repair time", _format_duration(repair_duration)),
+        ("Repair avg", f"{repair_duration / stats.diagrams_invalid:.3f}s/diagram" if stats.diagrams_invalid > 0 else "-"),
+        ("Retry failed", "yes" if retry_failed else "no"),
         ("Only show error", "yes" if only_show_error else "no"),
         ("Report", _relative_path(report_target) if report_target else "-"),
     ]
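
Note: the new timing rows guard their per-diagram averages against zero counts. `_format_duration` itself is not shown in this diff; a hypothetical stand-in with the same role (the real helper lives elsewhere in recognize/cli.py and may format differently):

    # Hypothetical stand-in for _format_duration, for illustration only.
    def format_duration(seconds: float) -> str:
        minutes, secs = divmod(seconds, 60.0)
        return f"{int(minutes)}m{secs:04.1f}s" if minutes else f"{secs:.1f}s"

    assert format_duration(7.25) == "7.2s"
    assert format_duration(83.0) == "1m23.0s"
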