deepresearch-flow 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the publicly released contents of the two package versions as they appear in their public registry. It is provided for informational purposes only.
- deepresearch_flow/paper/web/constants.py +4 -4
- deepresearch_flow/paper/web/templates/detail.html +3 -3
- deepresearch_flow/paper/web/templates/index.html +3 -3
- deepresearch_flow/recognize/cli.py +805 -26
- deepresearch_flow/recognize/katex_check.js +29 -0
- deepresearch_flow/recognize/math.py +719 -0
- deepresearch_flow/recognize/mermaid.py +690 -0
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.4.1.dist-info}/METADATA +56 -3
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.4.1.dist-info}/RECORD +13 -10
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.4.1.dist-info}/WHEEL +0 -0
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.4.1.dist-info}/entry_points.txt +0 -0
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.4.1.dist-info}/top_level.txt +0 -0
deepresearch_flow/recognize/mermaid.py (new file)

@@ -0,0 +1,690 @@

```python
"""Mermaid validation and repair helpers."""

from __future__ import annotations

import asyncio
from dataclasses import dataclass
import json
import logging
from pathlib import Path
import re
import shutil
import subprocess
import tempfile
from typing import Any, Callable, Iterable

import httpx

from deepresearch_flow.paper.llm import backoff_delay, call_provider
from deepresearch_flow.paper.providers.base import ProviderError
from deepresearch_flow.paper.utils import parse_json, short_hash

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class MermaidSpan:
    start: int
    end: int
    content: str
    line: int
    context: str


@dataclass
class MermaidIssue:
    issue_id: str
    span: MermaidSpan
    errors: list[str]
    field_path: str | None
    item_index: int | None


@dataclass
class MermaidFixStats:
    diagrams_total: int = 0
    diagrams_invalid: int = 0
    diagrams_repaired: int = 0
    diagrams_failed: int = 0


_MMDC_PATH: str | None = None


def require_mmdc() -> None:
    global _MMDC_PATH
    if _MMDC_PATH:
        return
    local_mmdc = Path.cwd() / "node_modules" / ".bin" / "mmdc"
    if local_mmdc.exists():
        _MMDC_PATH = str(local_mmdc)
        return
    _MMDC_PATH = shutil.which("mmdc")
    if not _MMDC_PATH:
        raise RuntimeError("mmdc (mermaid-cli) not found; install with npm i -g @mermaid-js/mermaid-cli")


def extract_mermaid_spans(text: str, context_chars: int) -> list[MermaidSpan]:
    spans: list[MermaidSpan] = []
    pattern = re.compile(r"```\s*mermaid\s*\n([\s\S]*?)```", re.IGNORECASE)
    for match in pattern.finditer(text):
        content = match.group(1)
        content_start = match.start(1)
        content_end = match.end(1)
        line = text.count("\n", 0, content_start) + 1
        context = text[max(0, match.start() - context_chars) : match.end() + context_chars]
        spans.append(
            MermaidSpan(
                start=content_start,
                end=content_end,
                content=content,
                line=line,
                context=context,
            )
        )
    return spans
```
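The fence match is case-insensitive, and each span records only the diagram body (fences excluded) plus its 1-based starting line. A quick usage sketch, assuming the module above is importable as `deepresearch_flow.recognize.mermaid` (the `doc` input is invented):

```python
from deepresearch_flow.recognize.mermaid import extract_mermaid_spans

fence = "`" * 3  # avoid literal fences inside this example
doc = f"intro\n{fence}mermaid\ngraph TD\nA --> B\n{fence}\nafter"
spans = extract_mermaid_spans(doc, context_chars=20)
assert len(spans) == 1
assert spans[0].content == "graph TD\nA --> B\n"  # body only, fences excluded
assert spans[0].line == 3  # 1-based line where the diagram body starts
```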
```python
def cleanup_mermaid(text: str) -> str:
    cleaned = text.replace("\r\n", "\n").replace("\r", "\n").lstrip("\ufeff")
    cleaned = cleaned.replace("\u2028", "\n").replace("\u2029", "\n")
    cleaned = cleaned.replace("“", "\"").replace("”", "\"").replace("‘", "'").replace("’", "'")
    cleaned = _expand_escaped_newlines(cleaned)
    cleaned = _normalize_mermaid_lines(cleaned)
    cleaned = _normalize_subgraph_lines(cleaned)
    cleaned = _repair_edge_labels(cleaned)
    cleaned = _repair_missing_label_arrows(cleaned)
    cleaned = _normalize_label_linebreaks(cleaned)
    cleaned = _normalize_cylinder_labels(cleaned)
    cleaned = _wrap_html_labels(cleaned)
    cleaned = _close_unbalanced_labels(cleaned)
    cleaned = _split_compacted_statements(cleaned)
    cleaned = _split_chained_edges(cleaned)
    cleaned = _repair_dot_label_edges(cleaned)
    cleaned = _expand_multi_source_edges(cleaned)
    cleaned = _prefix_orphan_edges(cleaned)
    cleaned = _dedupe_subgraph_ids(cleaned)
    return cleaned
```
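`cleanup_mermaid` chains the normalization passes defined below in a fixed order. Two small traced examples of the net effect (illustrative inputs, not from the package):

```python
from deepresearch_flow.recognize.mermaid import cleanup_mermaid

# ";"-compacted statements are split onto separate lines.
assert cleanup_mermaid("graph TD; A-->B") == "graph TD\nA-->B"
# A literal backslash-n between statements becomes a real newline.
assert cleanup_mermaid("graph TD\\nA --> B") == "graph TD\nA --> B"
```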
```python
def _expand_escaped_newlines(text: str) -> str:
    out: list[str] = []
    depth = 0
    i = 0
    while i < len(text):
        ch = text[i]
        if ch in "[({":
            depth += 1
        elif ch in "])}":
            depth = max(0, depth - 1)
        if ch == "\\" and i + 1 < len(text) and text[i + 1] == "n":
            out.append("<br/>" if depth > 0 else "\n")
            i += 2
            continue
        out.append(ch)
        i += 1
    return "".join(out)


def _normalize_mermaid_lines(text: str) -> str:
    out_lines: list[str] = []
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            out_lines.append("")
            continue
        if line.startswith("%%"):
            out_lines.append(line)
            continue
        if "%%" in line:
            code, comment = line.split("%%", 1)
            code = code.strip()
            if code:
                out_lines.extend(_split_statements(code))
            comment = comment.strip()
            if comment:
                out_lines.append(f"%% {comment}")
            continue
        line = _split_statements(line)
        out_lines.extend(line)
    return "\n".join(out_lines)


def _split_statements(line: str) -> list[str]:
    parts: list[str] = []
    current: list[str] = []
    depth = 0
    for ch in line:
        if ch in "[({":
            depth += 1
        elif ch in "])}":
            depth = max(0, depth - 1)
        if ch == ";" and depth == 0:
            part = "".join(current).strip()
            if part:
                parts.append(part)
            current = []
            continue
        current.append(ch)
    tail = "".join(current).strip()
    if tail:
        parts.append(tail)
    return parts or [line]
```
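Both helpers track bracket depth so that label text is left alone. Traced examples (illustrative inputs; the private names are imported only for demonstration):

```python
from deepresearch_flow.recognize.mermaid import (
    _expand_escaped_newlines,
    _split_statements,
)

# Inside a bracketed label, "\n" becomes <br/>; outside, a real newline.
assert _expand_escaped_newlines('A["x\\ny"] --> B') == 'A["x<br/>y"] --> B'
assert _expand_escaped_newlines("graph TD\\nA --> B") == "graph TD\nA --> B"

# Semicolons split statements only at bracket depth zero.
assert _split_statements('A --> B; B --> C["a;b"]') == ['A --> B', 'B --> C["a;b"]']
```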
```python
def _normalize_subgraph_lines(text: str) -> str:
    lines = text.splitlines()
    normalized: list[str] = []
    counter = 0
    for line in lines:
        match = re.match(r"\s*subgraph\s+(.+)", line)
        if not match:
            normalized.append(line)
            continue
        rest = match.group(1).strip()
        label_match = re.match(r"(.+?)\s*\[(.+)\]\s*$", rest)
        if label_match:
            counter += 1
            label = label_match.group(2).replace("[", "(").replace("]", ")").strip()
            label = _quote_label_text(label)
            sub_id = f"subgraph_{counter}"
            normalized.append(f"subgraph {sub_id} [{label}]")
            continue
        counter += 1
        label = rest.replace("[", "(").replace("]", ")").strip()
        label = _quote_label_text(label)
        sub_id = f"subgraph_{counter}"
        normalized.append(f"subgraph {sub_id} [{label}]")
    return "\n".join(normalized)
```
```python
def _normalize_label_linebreaks(text: str) -> str:
    def fix_block(block: str) -> str:
        if "-->" in block or "-.->" in block or "==>" in block or "\nsubgraph" in block or "\nend" in block:
            return block
        return block.replace("\n", "<br/>")

    return re.sub(
        r"\[[^\]]*\]",
        lambda match: fix_block(match.group(0)),
        text,
    )


def _repair_edge_labels(text: str) -> str:
    return re.sub(r"-->\s*\[([^\]]+)\]\s*", r"-->|\1| ", text)


def _repair_missing_label_arrows(text: str) -> str:
    return re.sub(r"--\s*([^|>]+)\|\s*", r"-->|\1| ", text)


def _split_compacted_statements(text: str) -> str:
    token_start = r"[A-Za-z0-9_\u4e00-\u9fff]"
    out = text
    out = re.sub(
        r"^(\s*(?:graph|flowchart)\s+[A-Za-z]{2})\s*(?=\S)",
        r"\1\n",
        out,
        flags=re.MULTILINE,
    )
    out = re.sub(
        rf"^(\s*(?:graph|flowchart)\s+[A-Za-z]{{2}})(?={token_start})",
        r"\1\n",
        out,
        flags=re.MULTILINE,
    )
    out = re.sub(rf"([)\]}}])\s*(?={token_start}+\s*-->)", r"\1\n", out)
    out = re.sub(rf"([)\]}}])\s*(?={token_start}+\s*[\[\(\{{])", r"\1\n", out)
    out = re.sub(rf"([)\]}}])(?={token_start})", r"\1\n", out)
    return out


def _prefix_orphan_edges(text: str) -> str:
    lines = text.splitlines()
    for idx, line in enumerate(lines):
        if re.match(r"^\s*(-->|-\.-|==>)", line):
            lines[idx] = f"Start {line}"
    return "\n".join(lines)
```
```python
def _normalize_cylinder_labels(text: str) -> str:
    lines = []
    for line in text.splitlines():
        if "[(" in line:
            if ")]" in line:
                line = line.replace("[(", '["(').replace(")]", ')"]')
            else:
                def fix_label(match: re.Match[str]) -> str:
                    inner = match.group(1).strip()
                    wrapped = f"({inner}" if ")" in inner else f"({inner})"
                    return f'[\"{wrapped}\"]'

                line = re.sub(r"\[\(([^]]+)\]", fix_label, line)
        lines.append(line)
    return "\n".join(lines)


def _wrap_html_labels(text: str) -> str:
    def repl(match: re.Match[str]) -> str:
        inner = match.group(1).replace('"', "'")
        return f'["{inner}"]'

    return re.sub(
        r"\[([^\]]*<br\s*/?>[^\]]*)\]",
        repl,
        text,
    )


def _close_unbalanced_labels(text: str) -> str:
    lines = []
    for line in text.splitlines():
        if line.startswith("%%"):
            lines.append(line)
            continue
        opens = line.count("[")
        closes = line.count("]")
        if opens > closes:
            line = line + ("]" * (opens - closes))
        lines.append(line)
    return "\n".join(lines)


def _quote_label_text(label: str) -> str:
    stripped = label.strip()
    if stripped.startswith(('"', "'")) and stripped.endswith(('"', "'")):
        return stripped
    return f'"{stripped.replace(chr(34), chr(39))}"'
```
```python
def _split_chained_edges(text: str) -> str:
    out: list[str] = []
    arrows = ("-->", "-.->", "==>")
    for line in text.splitlines():
        if line.startswith("%%") or "subgraph" in line or line.strip() == "end":
            out.append(line)
            continue
        segments: list[str] = []
        tokens: list[str] = []
        buf: list[str] = []
        depth = 0
        i = 0
        while i < len(line):
            ch = line[i]
            if ch in "[({":
                depth += 1
            elif ch in "])}":
                depth = max(0, depth - 1)
            if depth == 0:
                matched = None
                for arrow in arrows:
                    if line.startswith(arrow, i):
                        matched = arrow
                        break
                if matched:
                    segments.append("".join(buf).strip())
                    tokens.append(matched)
                    buf = []
                    i += len(matched)
                    continue
            buf.append(ch)
            i += 1
        tail = "".join(buf).strip()
        if tail:
            segments.append(tail)
        if len(segments) <= 2:
            out.append(line)
            continue
        for idx, arrow in enumerate(tokens):
            left = segments[idx]
            right = segments[idx + 1]
            if left and right:
                out.append(f"{left} {arrow} {right}")
    return "\n".join(out)
```
```python
def _repair_dot_label_edges(text: str) -> str:
    def fix_line(line: str) -> str:
        return re.sub(r"-\.\s*([^|>]+?)\s*\.-\s*>", r"-.->|\1|", line)

    return "\n".join(fix_line(line) for line in text.splitlines())


def _expand_multi_source_edges(text: str) -> str:
    lines = text.splitlines()
    out: list[str] = []
    pattern = re.compile(
        r"^(?P<indent>\s*)(?P<left>[A-Za-z0-9_]+(?:\s*&\s*[A-Za-z0-9_]+)+)\s*-->\s*(?P<right>.+)$"
    )
    for line in lines:
        match = pattern.match(line)
        if not match:
            out.append(line)
            continue
        indent = match.group("indent")
        right = match.group("right").strip()
        for node in match.group("left").split("&"):
            out.append(f"{indent}{node.strip()} --> {right}")
    return "\n".join(out)


def _dedupe_subgraph_ids(text: str) -> str:
    lines = text.splitlines()
    idx = 0
    while idx < len(lines):
        line = lines[idx]
        match = re.match(r"\s*subgraph\s+([A-Za-z0-9_]+)\b", line)
        if not match:
            idx += 1
            continue
        sub_id = match.group(1)
        end_idx = idx + 1
        while end_idx < len(lines) and not re.match(r"\s*end\b", lines[end_idx]):
            end_idx += 1
        conflict = any(re.match(rf"\s*{re.escape(sub_id)}\b", ln) for ln in lines[idx + 1 : end_idx])
        if conflict:
            new_id = f"{sub_id}_group"
            lines[idx] = line.replace(sub_id, new_id, 1)
        idx = end_idx + 1
    return "\n".join(lines)
```
```python
def _mermaid_temp_dir() -> Path:
    base_dir = Path("/tmp/mermaid")
    base_dir.mkdir(parents=True, exist_ok=True)
    return base_dir


def validate_mermaid(mmd_text: str) -> str | None:
    require_mmdc()
    base_dir = _mermaid_temp_dir()
    with tempfile.TemporaryDirectory(dir=base_dir, prefix="mmdc-") as td:
        work_dir = Path(td)
        in_file = work_dir / "diagram.mmd"
        out_file = work_dir / "diagram.svg"
        in_file.write_text(mmd_text, encoding="utf-8")
        cmd = [
            _MMDC_PATH or "mmdc",
            "-i",
            str(in_file),
            "-o",
            str(out_file),
            "--quiet",
        ]
        proc = subprocess.run(cmd, capture_output=True, text=True, encoding="utf-8")
        if proc.returncode == 0 and out_file.exists() and out_file.stat().st_size > 0:
            return None
        msg = (proc.stderr or "") + "\n" + (proc.stdout or "")
        msg = msg.strip() or f"mmdc failed with code {proc.returncode}"
        return msg
```
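`validate_mermaid` shells out to mmdc and treats a non-empty rendered SVG as success; note the hard-coded `/tmp/mermaid` base directory, which assumes a POSIX temp layout. A minimal usage sketch (requires mermaid-cli on PATH or in `./node_modules/.bin`):

```python
from deepresearch_flow.recognize.mermaid import validate_mermaid

# Returns None when mmdc renders the diagram, else the mmdc error text.
error = validate_mermaid("graph TD\nA --> B")
if error is not None:
    print(f"diagram rejected: {error}")
```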
```python
def apply_replacements(text: str, replacements: list[tuple[int, int, str]]) -> str:
    if not replacements:
        return text
    updated = text
    for start, end, value in sorted(replacements, key=lambda item: item[0], reverse=True):
        updated = updated[:start] + value + updated[end:]
    return updated
```
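Replacements are applied right-to-left so that earlier character offsets stay valid while later spans are rewritten. A traced example (invented input):

```python
from deepresearch_flow.recognize.mermaid import apply_replacements

assert apply_replacements("abcdef", [(0, 2, "XY"), (4, 6, "Z")]) == "XYcdZ"
```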
```python
def repair_schema() -> dict[str, Any]:
    return {
        "type": "object",
        "properties": {
            "items": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "id": {"type": "string"},
                        "mermaid": {"type": "string"},
                    },
                    "required": ["id", "mermaid"],
                },
            }
        },
        "required": ["items"],
    }


def build_repair_messages(issues: list[MermaidIssue]) -> list[dict[str, str]]:
    payload = [
        {
            "id": issue.issue_id,
            "mermaid": issue.span.content,
            "errors": issue.errors,
            "context": issue.span.context,
        }
        for issue in issues
    ]
    system = (
        "You repair Mermaid diagrams. Fix syntax errors only and keep the "
        "meaning unchanged. Return JSON with key 'items' and each item "
        "containing {\"id\", \"mermaid\"}. Output JSON only.\n\n"
        "Use this minimal safe subset for all repaired Mermaid output:\n"
        "- Only use: graph TD\n"
        "- Node IDs: ASCII letters/digits/underscore only\n"
        "- Node labels: id[\"中文...\"]\n"
        "- Line breaks in labels: use <br/> only\n"
        "- Subgraphs: use subgraph sgN[\"中文标题\"] (no Chinese in IDs)\n"
        "- No inline comments (remove %% lines)\n"
        "- Do not use special shapes like [(...)], just use [\"...\"]\n"
        "- One statement per line; do not glue multiple edges on one line\n"
        "- Do not use multi-source edges with &: expand into multiple edges\n"
    )
    user = json.dumps({"items": payload}, ensure_ascii=False, indent=2)
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
```
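For reference, the response shape the schema accepts; the `id` echoes the `short_hash(file_path):index` keys built in `fix_mermaid_text` below (the hash value here is invented):

```python
example_response = {
    "items": [
        {"id": "ab12cd34:0", "mermaid": "graph TD\nA --> B"},
    ]
}
```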
```python
def iter_batches(items: list[MermaidIssue], batch_size: int) -> Iterable[list[MermaidIssue]]:
    for idx in range(0, len(items), batch_size):
        yield items[idx : idx + batch_size]


def _parse_repairs(response: str) -> dict[str, str]:
    parsed = parse_json(response)
    items = parsed.get("items", [])
    repairs: dict[str, str] = {}
    if isinstance(items, list):
        for item in items:
            if not isinstance(item, dict):
                continue
            issue_id = item.get("id")
            mermaid = item.get("mermaid")
            if isinstance(issue_id, str) and isinstance(mermaid, str):
                repairs[issue_id] = mermaid
    return repairs


def strip_mermaid_fences(text: str) -> str:
    stripped = text.strip()
    if stripped.startswith("```") and stripped.endswith("```"):
        stripped = stripped.strip("`")
        stripped = stripped.replace("mermaid", "", 1)
    return stripped.strip()
```
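A traced example of fence stripping (invented input); note the `replace("mermaid", "", 1)` removes the first occurrence anywhere, which targets the fence language tag:

```python
from deepresearch_flow.recognize.mermaid import strip_mermaid_fences

fence = "`" * 3  # avoid literal fences inside this example
assert strip_mermaid_fences(f"{fence}mermaid\ngraph TD\nA --> B\n{fence}") == "graph TD\nA --> B"
```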
```python
async def repair_batch(
    issues: list[MermaidIssue],
    provider,
    model_name: str,
    api_key: str | None,
    timeout: float,
    max_retries: int,
    client: httpx.AsyncClient,
) -> tuple[dict[str, str], str | None]:
    messages = build_repair_messages(issues)
    schema = repair_schema()
    last_error: str | None = None
    for attempt in range(max_retries + 1):
        try:
            response = await call_provider(
                provider,
                model_name,
                messages,
                schema,
                api_key,
                timeout,
                provider.structured_mode,
                client,
                max_tokens=provider.max_tokens,
            )
            repairs = _parse_repairs(response)
            return repairs, None
        except (ValueError, TypeError) as exc:
            last_error = f"parse_error: {exc}"
            if attempt < max_retries:
                await asyncio.sleep(backoff_delay(1.0, attempt + 1, 20.0))
                continue
            return {}, last_error
        except ProviderError as exc:
            last_error = str(exc)
            if exc.retryable and attempt < max_retries:
                await asyncio.sleep(backoff_delay(1.0, attempt + 1, 20.0))
                continue
            return {}, last_error
        except Exception as exc:  # pragma: no cover - safety net
            last_error = str(exc)
            if attempt < max_retries:
                await asyncio.sleep(backoff_delay(1.0, attempt + 1, 20.0))
                continue
            return {}, last_error
    return {}, last_error


async def fix_mermaid_text(
    text: str,
    file_path: str,
    line_offset: int,
    field_path: str | None,
    item_index: int | None,
    provider,
    model_name: str,
    api_key: str | None,
    timeout: float,
    max_retries: int,
    batch_size: int,
    context_chars: int,
    client: httpx.AsyncClient,
    stats: MermaidFixStats,
    repair_enabled: bool = True,
    spans: list[MermaidSpan] | None = None,
    progress_cb: Callable[[], None] | None = None,
) -> tuple[str, list[dict[str, Any]]]:
    replacements: list[tuple[int, int, str]] = []
    issues: list[MermaidIssue] = []
    if spans is None:
        spans = extract_mermaid_spans(text, context_chars)
    stats.diagrams_total += len(spans)
    file_id = short_hash(file_path)
    for idx, span in enumerate(spans):
        errors: list[str] = []
        original = span.content
        validation = validate_mermaid(original)
        candidate = original
        if validation:
            candidate = cleanup_mermaid(original)
            if candidate != original:
                candidate_validation = validate_mermaid(candidate)
                if not candidate_validation:
                    replacements.append((span.start, span.end, candidate))
                    if progress_cb:
                        progress_cb()
                    continue
                validation = candidate_validation
            errors.append(validation)
            stats.diagrams_invalid += 1
            issue_id = f"{file_id}:{idx}"
            issues.append(
                MermaidIssue(
                    issue_id=issue_id,
                    span=span,
                    errors=errors,
                    field_path=field_path,
                    item_index=item_index,
                )
            )
        if progress_cb:
            progress_cb()

    error_records: list[dict[str, Any]] = []
    if issues and repair_enabled:
        for batch in iter_batches(issues, batch_size):
            repairs, error = await repair_batch(
                batch,
                provider,
                model_name,
                api_key,
                timeout,
                max_retries,
                client,
            )
            if error:
                for issue in batch:
                    stats.diagrams_failed += 1
                    error_records.append(
                        {
                            "path": file_path,
                            "line": line_offset + issue.span.line - 1,
                            "mermaid": issue.span.content,
                            "errors": issue.errors + [f"llm_error: {error}"],
                            "field_path": issue.field_path,
                            "item_index": issue.item_index,
                        }
                    )
                continue

            for issue in batch:
                repaired = repairs.get(issue.issue_id)
                if not repaired:
                    stats.diagrams_failed += 1
                    error_records.append(
                        {
                            "path": file_path,
                            "line": line_offset + issue.span.line - 1,
                            "mermaid": issue.span.content,
                            "errors": issue.errors + ["llm_missing_output"],
                            "field_path": issue.field_path,
                            "item_index": issue.item_index,
                        }
                    )
                    continue
                repaired = strip_mermaid_fences(repaired)
                repaired = cleanup_mermaid(repaired)
                validation = validate_mermaid(repaired)
                if validation:
                    stats.diagrams_failed += 1
                    error_records.append(
                        {
                            "path": file_path,
                            "line": line_offset + issue.span.line - 1,
                            "mermaid": issue.span.content,
                            "errors": issue.errors + [validation],
                            "field_path": issue.field_path,
                            "item_index": issue.item_index,
                        }
                    )
                    continue
                stats.diagrams_repaired += 1
                replacements.append((issue.span.start, issue.span.end, repaired))
    elif issues:
        for issue in issues:
            stats.diagrams_failed += 1
            error_records.append(
                {
                    "path": file_path,
                    "line": line_offset + issue.span.line - 1,
                    "mermaid": issue.span.content,
                    "errors": issue.errors + ["validation_only"],
                    "field_path": issue.field_path,
                    "item_index": issue.item_index,
                }
            )

    updated = apply_replacements(text, replacements)
    return updated, error_records
```