deepresearch-flow 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepresearch_flow/paper/db.py +34 -0
- deepresearch_flow/paper/web/app.py +106 -1
- deepresearch_flow/paper/web/constants.py +5 -4
- deepresearch_flow/paper/web/handlers/__init__.py +2 -1
- deepresearch_flow/paper/web/handlers/api.py +55 -0
- deepresearch_flow/paper/web/handlers/pages.py +105 -25
- deepresearch_flow/paper/web/markdown.py +60 -0
- deepresearch_flow/paper/web/pdfjs/web/viewer.html +57 -5
- deepresearch_flow/paper/web/pdfjs/web/viewer.js +5 -1
- deepresearch_flow/paper/web/static/js/detail.js +494 -125
- deepresearch_flow/paper/web/static/js/outline.js +48 -34
- deepresearch_flow/paper/web/static_assets.py +289 -0
- deepresearch_flow/paper/web/templates/detail.html +46 -69
- deepresearch_flow/paper/web/templates/index.html +3 -3
- deepresearch_flow/paper/web/templates.py +7 -4
- deepresearch_flow/recognize/cli.py +805 -26
- deepresearch_flow/recognize/katex_check.js +29 -0
- deepresearch_flow/recognize/math.py +719 -0
- deepresearch_flow/recognize/mermaid.py +690 -0
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.5.0.dist-info}/METADATA +117 -4
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.5.0.dist-info}/RECORD +25 -21
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.5.0.dist-info}/WHEEL +0 -0
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.5.0.dist-info}/entry_points.txt +0 -0
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {deepresearch_flow-0.4.0.dist-info → deepresearch_flow-0.5.0.dist-info}/top_level.txt +0 -0
deepresearch_flow/recognize/math.py (new file)
@@ -0,0 +1,719 @@
+"""Math formula validation and repair helpers."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+import asyncio
+import json
+import logging
+import re
+import atexit
+from pathlib import Path
+import shutil
+import subprocess
+import threading
+from typing import Any, Callable, Iterable
+
+import httpx
+
+from deepresearch_flow.paper.llm import backoff_delay, call_provider
+from deepresearch_flow.paper.providers.base import ProviderError
+from deepresearch_flow.paper.utils import parse_json, short_hash
+
+try:
+    from pylatexenc.latexwalker import LatexWalker, LatexWalkerError
+except ImportError:  # pragma: no cover - dependency guard
+    LatexWalker = None
+    LatexWalkerError = Exception
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class FormulaSpan:
+    start: int
+    end: int
+    delimiter: str
+    content: str
+    line: int
+    context: str
+
+
+@dataclass
+class FormulaIssue:
+    issue_id: str
+    span: FormulaSpan
+    errors: list[str]
+    cleaned: str
+    field_path: str | None
+    item_index: int | None
+
+
+@dataclass
+class MathFixStats:
+    formulas_total: int = 0
+    formulas_invalid: int = 0
+    formulas_cleaned: int = 0
+    formulas_repaired: int = 0
+    formulas_failed: int = 0
+
+
+_KATEX_WARNED = False
+_NODE_VALIDATOR: "NodeKatexValidator | None" = None
+_NODE_KATEX_READY: bool | None = None
+
+
+def require_pylatexenc() -> None:
+    if LatexWalker is None:
+        raise RuntimeError("pylatexenc is required for fix-math")
+
+
+def _mask_regex(text: str, pattern: str, flags: int = 0) -> str:
+    masked = list(text)
+    for match in re.finditer(pattern, text, flags):
+        for idx in range(match.start(), match.end()):
+            masked[idx] = " "
+    return "".join(masked)
+
+
+def _mask_code(text: str) -> str:
+    masked = _mask_regex(text, r"```[\s\S]*?```")
+    masked = _mask_regex(masked, r"(?<!`)(`+)([^`\n]+?)\1(?!`)")
+    return masked
+
+
+def extract_math_spans(text: str, context_chars: int) -> list[FormulaSpan]:
+    masked = _mask_code(text)
+    spans: list[FormulaSpan] = []
+    for match in re.finditer(r"(?<!\\)\$\$([\s\S]+?)(?<!\\)\$\$", masked):
+        content = text[match.start() + 2 : match.end() - 2]
+        line = text.count("\n", 0, match.start()) + 1
+        context = text[max(0, match.start() - context_chars) : match.end() + context_chars]
+        spans.append(
+            FormulaSpan(
+                start=match.start(),
+                end=match.end(),
+                delimiter="$$",
+                content=content,
+                line=line,
+                context=context,
+            )
+        )
+
+    block_ranges = [(span.start, span.end) for span in spans]
+    inline_pattern = re.compile(r"(?<!\\)\$(?!\s|\$)([^$\n]+?)(?<!\\)\$(?!\$)")
+    for match in inline_pattern.finditer(masked):
+        if any(start <= match.start() < end for start, end in block_ranges):
+            continue
+        content = text[match.start() + 1 : match.end() - 1]
+        line = text.count("\n", 0, match.start()) + 1
+        context = text[max(0, match.start() - context_chars) : match.end() + context_chars]
+        spans.append(
+            FormulaSpan(
+                start=match.start(),
+                end=match.end(),
+                delimiter="$",
+                content=content,
+                line=line,
+                context=context,
+            )
+        )
+
+    return sorted(spans, key=lambda span: span.start)
+
+
+def cleanup_formula(text: str) -> str:
+    cleaned = text.replace("\u00a0", " ").strip()
+    cleaned = cleaned.replace("\r", "")
+    cleaned = cleaned.replace("\text", "\\text").replace("\tfrac", "\\tfrac")
+    cleaned = cleaned.replace("\t", "")
+    cleaned = re.sub(r"\\(?=\d)", "", cleaned)
+    cleaned = re.sub(r"\\n(?=[A-Za-z])", "", cleaned)
+    cleaned = re.sub(r"\\\s+([A-Za-z])", r"\\\1", cleaned)
+    cleaned = re.sub(r"[ \t]+\n", "\n", cleaned)
+    cleaned = re.sub(r"\n[ \t]+", "\n", cleaned)
+    cleaned = re.sub(r"\\\(\s*cdot\s*\)", r"\\cdot", cleaned)
+    cleaned = re.sub(r"\\\s*Max\b", r"\\max", cleaned)
+    cleaned = re.sub(r"\\\s*Min\b", r"\\min", cleaned)
+    cleaned = re.sub(r"\x08eta(?=[_\\^,\\s)])", r"\\beta", cleaned)
+    cleaned = re.sub(r"\x08eta\b", r"\\beta", cleaned)
+    cleaned = re.sub(r"\x08ar(?=[_\\{\\^\\s)])", r"\\bar", cleaned)
+    cleaned = re.sub(r"\x08ar\b", r"\\bar", cleaned)
+    cleaned = re.sub(r"\\\s+([A-Za-z]{2,})", r"\\text{\1}", cleaned)
+    cleaned = re.sub(
+        r"([A-Za-z0-9_{}\\]+)\^([A-Za-z0-9_{}]+)\^([A-Za-z0-9_{}]+)",
+        r"({\1}^{\2})^{\3}",
+        cleaned,
+    )
+    cleaned = re.sub(r"(\\right)\s*ceil\b", r"\\right\\rceil", cleaned)
+    cleaned = re.sub(r"(\\left)\s*ceil\b", r"\\left\\lceil", cleaned)
+    cleaned = re.sub(
+        r"\\Big\s*{\\(lfloor|lceil|rfloor|rceil|langle|rangle)}",
+        r"\\Big\\\1",
+        cleaned,
+    )
+    cleaned = re.sub(r"\x08egin\b", r"\\begin", cleaned)
+    cleaned = re.sub(r"\x08oldsymbol\b", r"\\boldsymbol", cleaned)
+    cleaned = re.sub(r"\^''", r"^{''}", cleaned)
+    cleaned = re.sub(r"\^'", r"^{\\prime}", cleaned)
+    cleaned = re.sub(r"\^_", r"^{*}", cleaned)
+    cleaned = re.sub(r"\\operatorname_\s*{", r"\\operatorname{", cleaned)
+    cleaned = re.sub(r"_\s*{\\times}", r"\\times", cleaned)
+    cleaned = re.sub(r"_\s*\\times\b", r"\\times", cleaned)
+    cleaned = re.sub(r"\\arg\s+\\max\s*_\s*{", r"\\arg\\max_{", cleaned)
+    cleaned = re.sub(r"\^\s*{\s*_?\s*}", "", cleaned)
+    cleaned = re.sub(
+        r"([A-Za-z0-9]+_\s*{[^}]+})\s*_\s*([A-Za-z])",
+        r"\1 \2",
+        cleaned,
+    )
+    cleaned = _collapse_spaced_text_commands(cleaned)
+    cleaned = _split_text_with_math(cleaned)
+    cleaned = _normalize_unknown_commands(cleaned)
+    return cleaned
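
The helpers above locate `$...$` and `$$...$$` spans (skipping fenced and inline code) and apply deterministic cleanups for common OCR/LLM mangling: `\b` collapsed into a backspace character, spaced-out `\text{...}` contents, stray tabs, doubled superscripts, and so on. A minimal usage sketch, assuming the wheel is installed; the sample document string is made up for this note:

```python
from deepresearch_flow.recognize.math import cleanup_formula, extract_math_spans

doc = "Loss is $\x08eta_1 \\cdot \\text{m s e}$ over the batch."
spans = extract_math_spans(doc, context_chars=20)
for span in spans:
    # '\x08eta' (backspace + 'eta') is restored to '\\beta', and the
    # spaced-out letters inside \text{...} are collapsed to 'mse'.
    print(span.delimiter, repr(cleanup_formula(span.content)))
```
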
+
+
+def _collapse_spaced_text(text: str) -> str:
+    tokens = text.split()
+    if not tokens:
+        return text
+    out: list[str] = []
+    i = 0
+    while i < len(tokens):
+        if len(tokens[i]) == 1:
+            j = i
+            while j < len(tokens) and len(tokens[j]) == 1:
+                j += 1
+            out.append("".join(tokens[i:j]))
+            i = j
+        else:
+            out.append(tokens[i])
+            i += 1
+    return " ".join(out)
+
+
+def _collapse_spaced_text_commands(text: str) -> str:
+    def replace(match: re.Match[str]) -> str:
+        name = match.group(1)
+        content = match.group(2)
+        collapsed = _collapse_spaced_text(content)
+        return f"\\{name}{{{collapsed}}}"
+
+    return re.sub(r"\\(text|operatorname\*?)\s*{([^}]*)}", replace, text)
+
+
+def _split_text_with_math(text: str) -> str:
+    def replace(match: re.Match[str]) -> str:
+        content = match.group(1)
+        collapsed = _collapse_spaced_text(content)
+        token = re.search(r"(\\times|\\frac|\\sum|\\prod|\\left|\\right|\^|_)", collapsed)
+        if not token:
+            return f"\\text{{{collapsed}}}"
+        idx = token.start()
+        prefix = collapsed[:idx].rstrip()
+        suffix = collapsed[idx:].lstrip()
+        if not prefix:
+            return suffix
+        return f"\\text{{{prefix}}} {suffix}"
+
+    return re.sub(r"\\text\s*{([^}]*)}", replace, text)
+
+
+_KNOWN_LATEX_COMMANDS = {
+    "alpha",
+    "beta",
+    "gamma",
+    "delta",
+    "epsilon",
+    "varepsilon",
+    "zeta",
+    "eta",
+    "theta",
+    "vartheta",
+    "iota",
+    "kappa",
+    "lambda",
+    "mu",
+    "nu",
+    "xi",
+    "pi",
+    "rho",
+    "sigma",
+    "tau",
+    "upsilon",
+    "phi",
+    "varphi",
+    "chi",
+    "psi",
+    "omega",
+    "Gamma",
+    "Delta",
+    "Theta",
+    "Lambda",
+    "Xi",
+    "Pi",
+    "Sigma",
+    "Upsilon",
+    "Phi",
+    "Psi",
+    "Omega",
+    "sin",
+    "cos",
+    "tan",
+    "cot",
+    "sec",
+    "csc",
+    "arcsin",
+    "arccos",
+    "arctan",
+    "sinh",
+    "cosh",
+    "tanh",
+    "log",
+    "ln",
+    "exp",
+    "min",
+    "max",
+    "argmin",
+    "argmax",
+    "sqrt",
+    "frac",
+    "cdot",
+    "times",
+    "left",
+    "right",
+    "lceil",
+    "rceil",
+    "langle",
+    "rangle",
+    "lvert",
+    "rvert",
+    "lVert",
+    "rVert",
+    "sum",
+    "prod",
+    "int",
+    "lim",
+    "infty",
+    "partial",
+    "nabla",
+    "cdots",
+    "ldots",
+    "text",
+    "mathrm",
+    "mathbf",
+    "mathit",
+    "mathcal",
+    "mathbb",
+    "overline",
+    "underline",
+    "bar",
+    "hat",
+    "tilde",
+    "vec",
+}
+
+
+def _normalize_unknown_commands(text: str) -> str:
+    def replace(match: re.Match[str]) -> str:
+        command = match.group(1)
+        if command in _KNOWN_LATEX_COMMANDS or not command[:1].isupper():
+            return match.group(0)
+        return f"\\text{{{command}}}"
+
+    return re.sub(r"\\([A-Za-z]+)", replace, text)
+
+
+def _validate_pylatex(text: str) -> str | None:
+    require_pylatexenc()
+    try:
+        walker = LatexWalker(text)
+        walker.get_latex_nodes()
+    except LatexWalkerError as exc:
+        return str(exc)
+    except Exception as exc:  # pragma: no cover - safety net
+        return str(exc)
+    return None
+
+
+class NodeKatexValidator:
+    def __init__(self, node_path: str, script_path: str) -> None:
+        self._node_path = node_path
+        self._script_path = script_path
+        self._lock = threading.Lock()
+        self._process = self._spawn()
+        atexit.register(self.close)
+
+    def _spawn(self) -> subprocess.Popen[str]:
+        return subprocess.Popen(
+            [self._node_path, self._script_path],
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.DEVNULL,
+            text=True,
+            bufsize=1,
+        )
+
+    def _ensure_alive(self) -> None:
+        if self._process.poll() is not None:
+            self._process = self._spawn()
+
+    def close(self) -> None:
+        if self._process.poll() is None:
+            self._process.terminate()
+
+    def validate(self, latex: str, display_mode: bool) -> str | None:
+        with self._lock:
+            self._ensure_alive()
+            payload = {"latex": latex, "opts": {"displayMode": display_mode}}
+            try:
+                assert self._process.stdin is not None
+                assert self._process.stdout is not None
+                self._process.stdin.write(json.dumps(payload, ensure_ascii=False) + "\n")
+                self._process.stdin.flush()
+                line = self._process.stdout.readline()
+            except (BrokenPipeError, OSError) as exc:
+                return f"katex validator IO error: {exc}"
+        if not line:
+            return "katex validator returned empty response"
+        try:
+            response = json.loads(line)
+        except json.JSONDecodeError:
+            return "katex validator returned invalid JSON"
+        if response.get("ok") is True:
+            return None
+        return str(response.get("error") or "katex validation failed")
+
+
+def _ensure_node_validator() -> NodeKatexValidator | None:
+    global _KATEX_WARNED, _NODE_VALIDATOR, _NODE_KATEX_READY
+    if _NODE_VALIDATOR is not None:
+        return _NODE_VALIDATOR
+    node_path = shutil.which("node")
+    if not node_path:
+        if not _KATEX_WARNED:
+            logger.warning("node not available; skip KaTeX validation")
+            _KATEX_WARNED = True
+        return None
+    if _NODE_KATEX_READY is None:
+        try:
+            result = subprocess.run(
+                [node_path, "-e", "require('katex')"],
+                check=False,
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL,
+            )
+            _NODE_KATEX_READY = result.returncode == 0
+        except OSError:
+            _NODE_KATEX_READY = False
+    if not _NODE_KATEX_READY:
+        if not _KATEX_WARNED:
+            logger.warning("katex npm package not available; skip KaTeX validation")
+            _KATEX_WARNED = True
+        return None
+    script_path = str((Path(__file__).with_name("katex_check.js")).resolve())
+    _NODE_VALIDATOR = NodeKatexValidator(node_path, script_path)
+    return _NODE_VALIDATOR
+
+
+def _validate_katex(text: str, display_mode: bool) -> str | None:
+    validator = _ensure_node_validator()
+    if validator is None:
+        return None
+    return validator.validate(text, display_mode)
+
+
+def validate_formula(text: str, display_mode: bool) -> list[str]:
+    errors: list[str] = []
+    pylatex_error = _validate_pylatex(text)
+    if pylatex_error:
+        errors.append(pylatex_error)
+    katex_error = _validate_katex(text, display_mode)
+    if katex_error:
+        errors.append(katex_error)
+    return errors
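
`validate_formula` combines two checks: `pylatexenc` parsing (a hard requirement, see `require_pylatexenc`) and, when `node` plus the `katex` npm package are available, a persistent `katex_check.js` worker that is fed one JSON request per line (`{"latex": ..., "opts": {"displayMode": ...}}`) and whose reply is read back through its `ok`/`error` fields. A minimal calling sketch; the concrete error strings depend on which backends are installed:

```python
from deepresearch_flow.recognize.math import validate_formula

# Each entry in the returned list is one validator's error message; an empty
# list means every available validator accepted the formula.
errors = validate_formula(r"\frac{a}{b", display_mode=False)
for err in errors:
    print("validation error:", err)
```
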
+
+
+def apply_replacements(text: str, replacements: list[tuple[int, int, str]]) -> str:
+    if not replacements:
+        return text
+    updated = text
+    for start, end, value in sorted(replacements, key=lambda item: item[0], reverse=True):
+        updated = updated[:start] + value + updated[end:]
+    return updated
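
Replacements carry absolute character offsets into the original text and are applied from the end of the string backwards, so earlier offsets stay valid even when later spans change length. A quick illustration with made-up offsets:

```python
from deepresearch_flow.recognize.math import apply_replacements

text = "a $x^$ b $y_$ c"
# (start, end, replacement) triples index into the original string.
fixed = apply_replacements(text, [(2, 6, "$x^{2}$"), (9, 13, "$y_{i}$")])
print(fixed)  # -> a $x^{2}$ b $y_{i}$ c
```
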
+
+
+def repair_schema() -> dict[str, Any]:
+    return {
+        "type": "object",
+        "properties": {
+            "items": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "id": {"type": "string"},
+                        "latex": {"type": "string"},
+                    },
+                    "required": ["id", "latex"],
+                },
+            }
+        },
+        "required": ["items"],
+    }
+
+
+def build_repair_messages(issues: list[FormulaIssue]) -> list[dict[str, str]]:
+    payload = [
+        {
+            "id": issue.issue_id,
+            "delimiter": issue.span.delimiter,
+            "latex": issue.span.content,
+            "errors": issue.errors,
+            "context": issue.span.context,
+        }
+        for issue in issues
+    ]
+    system = (
+        "You repair LaTeX math expressions. Fix syntax errors only and keep the "
+        "mathematical meaning unchanged. Return JSON with key 'items' and each "
+        "item containing {\"id\", \"latex\"}. Output JSON only."
+    )
+    user = json.dumps({"items": payload}, ensure_ascii=False, indent=2)
+    return [
+        {"role": "system", "content": system},
+        {"role": "user", "content": user},
+    ]
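
`build_repair_messages` packages the broken formulas into one JSON payload for the LLM, and `repair_schema` constrains the structured response to `{"items": [{"id", "latex"}, ...]}`. A sketch of what a single-issue request looks like; the span, issue id, and error string below are hand-built for illustration:

```python
import json
from deepresearch_flow.recognize.math import (
    FormulaIssue,
    FormulaSpan,
    build_repair_messages,
)

span = FormulaSpan(start=10, end=24, delimiter="$", content=r"\frac{a}{b",
                   line=3, context=r"ratio $\frac{a}{b$ of means")
issue = FormulaIssue(issue_id="deadbeef:0", span=span,
                     errors=["example validator error"], cleaned=r"\frac{a}{b",
                     field_path=None, item_index=None)

messages = build_repair_messages([issue])
print(messages[0]["content"][:60])                        # start of the system instructions
print(json.loads(messages[1]["content"])["items"][0]["id"])  # -> deadbeef:0
```
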
+
+
+def iter_batches(items: list[FormulaIssue], batch_size: int) -> Iterable[list[FormulaIssue]]:
+    for idx in range(0, len(items), batch_size):
+        yield items[idx : idx + batch_size]
+
+
+def locate_json_field_start(
+    raw_text: str,
+    field_value: str,
+    search_start: int,
+) -> tuple[int, int]:
+    needle = json.dumps(field_value, ensure_ascii=False)
+    inner = needle[1:-1] if needle.startswith("\"") and needle.endswith("\"") else needle
+    idx = raw_text.find(inner, search_start)
+    if idx == -1:
+        idx = raw_text.find(needle, search_start)
+    if idx == -1:
+        return 1, search_start
+    line = raw_text.count("\n", 0, idx) + 1
+    return line, idx + len(inner)
+
+
+def strip_wrapping_delimiters(latex: str, delimiter: str) -> str:
+    latex = latex.strip()
+    if latex.startswith(delimiter) and latex.endswith(delimiter):
+        return latex[len(delimiter) : -len(delimiter)].strip()
+    return latex
+
+
+def _parse_repairs(response: str) -> dict[str, str]:
+    parsed = parse_json(response)
+    items = parsed.get("items", [])
+    repairs: dict[str, str] = {}
+    if isinstance(items, list):
+        for item in items:
+            if not isinstance(item, dict):
+                continue
+            issue_id = item.get("id")
+            latex = item.get("latex")
+            if isinstance(issue_id, str) and isinstance(latex, str):
+                repairs[issue_id] = latex
+    return repairs
+
+
+async def repair_batch(
+    issues: list[FormulaIssue],
+    provider,
+    model_name: str,
+    api_key: str | None,
+    timeout: float,
+    max_retries: int,
+    client: httpx.AsyncClient,
+) -> tuple[dict[str, str], str | None]:
+    messages = build_repair_messages(issues)
+    schema = repair_schema()
+    last_error: str | None = None
+    for attempt in range(max_retries + 1):
+        try:
+            response = await call_provider(
+                provider,
+                model_name,
+                messages,
+                schema,
+                api_key,
+                timeout,
+                provider.structured_mode,
+                client,
+                max_tokens=provider.max_tokens,
+            )
+            repairs = _parse_repairs(response)
+            return repairs, None
+        except (ValueError, TypeError) as exc:
+            last_error = f"parse_error: {exc}"
+            if attempt < max_retries:
+                await asyncio.sleep(backoff_delay(1.0, attempt + 1, 20.0))
+                continue
+            return {}, last_error
+        except ProviderError as exc:
+            last_error = str(exc)
+            if exc.retryable and attempt < max_retries:
+                await asyncio.sleep(backoff_delay(1.0, attempt + 1, 20.0))
+                continue
+            return {}, last_error
+        except Exception as exc:  # pragma: no cover - safety net
+            last_error = str(exc)
+            if attempt < max_retries:
+                await asyncio.sleep(backoff_delay(1.0, attempt + 1, 20.0))
+                continue
+            return {}, last_error
+    return {}, last_error
+
+
+async def fix_math_text(
+    text: str,
+    file_path: str,
+    line_offset: int,
+    field_path: str | None,
+    item_index: int | None,
+    provider,
+    model_name: str,
+    api_key: str | None,
+    timeout: float,
+    max_retries: int,
+    batch_size: int,
+    context_chars: int,
+    client: httpx.AsyncClient,
+    stats: MathFixStats,
+    repair_enabled: bool = True,
+    spans: list[FormulaSpan] | None = None,
+    progress_cb: Callable[[], None] | None = None,
+) -> tuple[str, list[dict[str, Any]]]:
+    replacements: list[tuple[int, int, str]] = []
+    issues: list[FormulaIssue] = []
+    if spans is None:
+        spans = extract_math_spans(text, context_chars)
+    stats.formulas_total += len(spans)
+    file_id = short_hash(file_path)
+    for idx, span in enumerate(spans):
+        display_mode = span.delimiter == "$$"
+        original = span.content
+        errors = validate_formula(original, display_mode)
+        cleaned = original
+        if errors:
+            candidate = cleanup_formula(original)
+            if candidate != original:
+                candidate_errors = validate_formula(candidate, display_mode)
+                if not candidate_errors:
+                    stats.formulas_cleaned += 1
+                    wrapped = f"{span.delimiter}{candidate}{span.delimiter}"
+                    replacements.append((span.start, span.end, wrapped))
+                    if progress_cb:
+                        progress_cb()
+                    continue
+                cleaned = candidate
+                errors = candidate_errors
+
+            stats.formulas_invalid += 1
+            issue_id = f"{file_id}:{idx}"
+            issues.append(
+                FormulaIssue(
+                    issue_id=issue_id,
+                    span=span,
+                    errors=errors,
+                    cleaned=cleaned,
+                    field_path=field_path,
+                    item_index=item_index,
+                )
+            )
+        if progress_cb:
+            progress_cb()
+
+    error_records: list[dict[str, Any]] = []
+    if issues and repair_enabled:
+        for batch in iter_batches(issues, batch_size):
+            repairs, error = await repair_batch(
+                batch,
+                provider,
+                model_name,
+                api_key,
+                timeout,
+                max_retries,
+                client,
+            )
+            if error:
+                for issue in batch:
+                    stats.formulas_failed += 1
+                    error_records.append(
+                        {
+                            "path": file_path,
+                            "line": line_offset + issue.span.line - 1,
+                            "delimiter": issue.span.delimiter,
+                            "latex": issue.span.content,
+                            "errors": issue.errors + [f"llm_error: {error}"],
+                            "field_path": issue.field_path,
+                            "item_index": issue.item_index,
+                        }
+                    )
+                continue
+
+            for issue in batch:
+                repaired = repairs.get(issue.issue_id)
+                if not repaired:
+                    stats.formulas_failed += 1
+                    error_records.append(
+                        {
+                            "path": file_path,
+                            "line": line_offset + issue.span.line - 1,
+                            "delimiter": issue.span.delimiter,
+                            "latex": issue.span.content,
+                            "errors": issue.errors + ["llm_missing_output"],
+                            "field_path": issue.field_path,
+                            "item_index": issue.item_index,
+                        }
+                    )
+                    continue
+                repaired = strip_wrapping_delimiters(repaired, issue.span.delimiter)
+                cleaned = cleanup_formula(repaired)
+                errors = validate_formula(cleaned, issue.span.delimiter == "$$")
+                if errors:
+                    stats.formulas_failed += 1
+                    error_records.append(
+                        {
+                            "path": file_path,
+                            "line": line_offset + issue.span.line - 1,
+                            "delimiter": issue.span.delimiter,
+                            "latex": issue.span.content,
+                            "errors": errors,
+                            "field_path": issue.field_path,
+                            "item_index": issue.item_index,
+                        }
+                    )
+                    continue
+                stats.formulas_repaired += 1
+                wrapped = f"{issue.span.delimiter}{cleaned}{issue.span.delimiter}"
+                replacements.append((issue.span.start, issue.span.end, wrapped))
+    elif issues:
+        for issue in issues:
+            stats.formulas_failed += 1
+            error_records.append(
+                {
+                    "path": file_path,
+                    "line": line_offset + issue.span.line - 1,
+                    "delimiter": issue.span.delimiter,
+                    "latex": issue.span.content,
+                    "errors": issue.errors + ["validation_only"],
+                    "field_path": issue.field_path,
+                    "item_index": issue.item_index,
+                }
+            )
+
+    updated = apply_replacements(text, replacements)
+    return updated, error_records
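
`fix_math_text` ties the pieces together: extract spans, validate, attempt the deterministic cleanup, and only then batch the leftovers to the LLM through `repair_batch`. With `repair_enabled=False` the provider and HTTP client are never touched, which allows a dry validation pass like the sketch below; the `None` placeholders stand in for the real provider object and `httpx.AsyncClient`, and `pylatexenc` must be installed:

```python
import asyncio
from deepresearch_flow.recognize.math import MathFixStats, fix_math_text

async def main() -> None:
    stats = MathFixStats()
    updated, errors = await fix_math_text(
        text="Energy $E = m c^{2}$ and a broken one $\\frac{a}{b$.",
        file_path="notes/example.md",
        line_offset=1,
        field_path=None,
        item_index=None,
        provider=None,          # not exercised when repair_enabled=False
        model_name="",
        api_key=None,
        timeout=30.0,
        max_retries=0,
        batch_size=8,
        context_chars=80,
        client=None,            # not exercised when repair_enabled=False
        stats=stats,
        repair_enabled=False,
    )
    print(stats)                # counts of total/invalid/cleaned/repaired/failed formulas
    for record in errors:       # formulas that could not be auto-cleaned are reported
        print(record["line"], record["errors"])

asyncio.run(main())
```
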