visual-parser 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- visual_parser/__init__.py +20 -0
- visual_parser/__main__.py +8 -0
- visual_parser/cli.py +230 -0
- visual_parser/cli_main.py +223 -0
- visual_parser/config.py +168 -0
- visual_parser/figure_describer.py +218 -0
- visual_parser/jsonl_writer.py +102 -0
- visual_parser/metadata_extractor.py +94 -0
- visual_parser/nougat_engine.py +222 -0
- visual_parser/pdf_tracker.py +105 -0
- visual_parser/pipeline.py +255 -0
- visual_parser/prompts.py +98 -0
- visual_parser/text_extractor.py +396 -0
- visual_parser/vision_llm.py +269 -0
- visual_parser-1.0.0.dist-info/METADATA +191 -0
- visual_parser-1.0.0.dist-info/RECORD +19 -0
- visual_parser-1.0.0.dist-info/WHEEL +5 -0
- visual_parser-1.0.0.dist-info/entry_points.txt +2 -0
- visual_parser-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""
|
|
2
|
+
figure_describer.py — Rasterise every page of each PDF at high DPI and send
|
|
3
|
+
each page image to a Vision LLM for figure extraction.
|
|
4
|
+
|
|
5
|
+
Extracted and cb-decoupled from the inner ``describe_figures_for_new_pdfs``
|
|
6
|
+
function in PDFAnalyser.py.
|
|
7
|
+
|
|
8
|
+
Output
|
|
9
|
+
------
|
|
10
|
+
One record per figure (or per page that contains at least one figure) is
|
|
11
|
+
appended to ``02_visuals_kb.jsonl`` in *output_dir*:
|
|
12
|
+
|
|
13
|
+
{
|
|
14
|
+
"source": "myreport.pdf",
|
|
15
|
+
"page": 3,
|
|
16
|
+
"document_id": "a1b2c3d4e5f6g7h8",
|
|
17
|
+
"figure_index": 0,
|
|
18
|
+
"figure_id": "a1b2c3d4e5f6g7h8:p3:f0",
|
|
19
|
+
"description": "**Subject:** ..."
|
|
20
|
+
}
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import json
|
|
26
|
+
import logging
|
|
27
|
+
import os
|
|
28
|
+
import re
|
|
29
|
+
from collections import defaultdict
|
|
30
|
+
from typing import Any, Dict, List, Optional
|
|
31
|
+
|
|
32
|
+
import fitz # PyMuPDF
|
|
33
|
+
|
|
34
|
+
from visual_parser.jsonl_writer import append_to_jsonl, make_document_id
|
|
35
|
+
from visual_parser.prompts import FIGURE_PROMPT
|
|
36
|
+
from visual_parser.vision_llm import call_vision_llm
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
# JSON response parser
|
|
43
|
+
# ---------------------------------------------------------------------------
|
|
44
|
+
|
|
45
|
+
def _parse_llm_response(
|
|
46
|
+
raw: str,
|
|
47
|
+
pdf_name: str,
|
|
48
|
+
page_number: Optional[int] = None,
|
|
49
|
+
) -> Optional[List[Dict]]:
|
|
50
|
+
"""
|
|
51
|
+
Parse the Vision LLM's JSON list response.
|
|
52
|
+
|
|
53
|
+
Strips markdown fences, tries ``json.loads``, then falls back to a regex
|
|
54
|
+
search for a JSON array if the model wraps it in prose.
|
|
55
|
+
"""
|
|
56
|
+
body = raw.strip()
|
|
57
|
+
for fence in ("```json", "```"):
|
|
58
|
+
body = body.replace(fence, "")
|
|
59
|
+
|
|
60
|
+
try:
|
|
61
|
+
return json.loads(body)
|
|
62
|
+
except json.JSONDecodeError:
|
|
63
|
+
match = re.search(r"(\[\s*\{.*?\}\s*\])", body, re.S)
|
|
64
|
+
if match:
|
|
65
|
+
try:
|
|
66
|
+
return json.loads(match.group(1))
|
|
67
|
+
except json.JSONDecodeError:
|
|
68
|
+
pass
|
|
69
|
+
label = f"{pdf_name} p{page_number}" if page_number else pdf_name
|
|
70
|
+
logger.warning("Could not parse JSON for %s: %r", label, body[:200])
|
|
71
|
+
return None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# ---------------------------------------------------------------------------
|
|
75
|
+
# Main function
|
|
76
|
+
# ---------------------------------------------------------------------------
|
|
77
|
+
|
|
78
|
+
def _rasterise_pdf_pages(pdf_full_path: str, raster_dpi: int) -> List[tuple]:
    """Render every page of one PDF to PNG; return ``(page_number, png_bytes)`` pairs.

    Page numbers are 1-based. Rendering is best-effort: if it fails part-way,
    the error is logged and the pages rendered before the failure are still
    returned.
    """
    pdf_name = os.path.basename(pdf_full_path)
    rendered: List[tuple] = []
    try:
        # The context manager closes the document even if a page fails to render.
        with fitz.open(pdf_full_path) as doc:
            for page_index, page in enumerate(doc):
                pix = page.get_pixmap(dpi=raster_dpi)
                rendered.append((page_index + 1, pix.tobytes("png")))
    except Exception as exc:
        logger.error("Error rasterising %s: %s", pdf_name, exc)
    return rendered


def _describe_page(
    image_bytes: bytes,
    pdf_name: str,
    page_number: int,
    llm_kwargs: Dict[str, Any],
) -> List[Any]:
    """Send one page image to the Vision LLM and return its figure descriptions.

    API errors, unparseable output and non-list output are logged and yield an
    empty list, so one bad page never aborts the run. Caption entries that are
    not dicts, or that have no ``description`` (or a ``None`` one), are dropped.
    """
    try:
        raw_response = call_vision_llm(images=[image_bytes], **llm_kwargs)
        captions = _parse_llm_response(raw_response, pdf_name, page_number)
    except Exception as exc:
        logger.error(
            "Vision LLM failed for %s page %d: %s",
            pdf_name, page_number, exc,
        )
        return []

    # Normalise: the model should return a list, but sometimes
    # returns a single dict for single-figure pages.
    if isinstance(captions, dict):
        captions = [captions]

    if not isinstance(captions, list):
        logger.warning(
            "Vision LLM returned non-list output for %s page %d",
            pdf_name, page_number,
        )
        return []

    return [
        caption["description"]
        for caption in captions
        if isinstance(caption, dict) and caption.get("description") is not None
    ]


def describe_figures_for_new_pdfs(
    new_pdf_paths: List[str],
    output_dir: str,
    vision_provider: str,
    vision_api_key: str,
    vision_model: str,
    vision_detail: str = "low",
    raster_dpi: int = 200,
    figure_prompt: str = FIGURE_PROMPT,
    reasoning_effort: Optional[str] = "medium",
) -> None:
    """
    Describe the figures on every page of each PDF and append them to
    ``02_visuals_kb.jsonl`` in *output_dir*.

    PDFs are processed one at a time. All pages of a PDF are rasterised at
    *raster_dpi*, then sent to the Vision LLM page by page. That PDF's page
    images are released before the next PDF is rendered, so peak memory is
    bounded by the largest single PDF rather than the whole batch. All
    records are written once at the end.

    Records are keyed by PDF *basename*, so two input paths that share a
    basename are merged under the same ``source`` / ``document_id``.

    Args:
        new_pdf_paths: Full paths of PDFs to describe.
        output_dir: Directory where ``02_visuals_kb.jsonl`` is written.
        vision_provider: ``'gpt'`` or ``'gemini'``.
        vision_api_key: API key for the chosen provider.
        vision_model: Vision model name string.
        vision_detail: Image detail level (GPT only).
        raster_dpi: DPI used when rasterising pages. 200 DPI gives a
            good balance between quality and API payload size.
        figure_prompt: The instruction prompt sent with each page image.
            Override this to customise for a specific domain.
        reasoning_effort: Passed through unchanged to ``call_vision_llm``.
    """
    llm_kwargs: Dict[str, Any] = {
        "prompt": figure_prompt,
        "provider": vision_provider,
        "api_key": vision_api_key,
        "model": vision_model,
        "detail": vision_detail,
        "reasoning_effort": reasoning_effort,
    }

    # (pdf basename, 1-based page) -> descriptions, in first-seen order.
    descriptions_by_page: Dict[tuple, List[str]] = {}
    pages_rendered = 0

    # -----------------------------------------------------------------------
    # Steps 1-3 – Rasterise and describe one PDF at a time
    # -----------------------------------------------------------------------
    for pdf_full_path in new_pdf_paths:
        pdf_name = os.path.basename(pdf_full_path)
        page_images = _rasterise_pdf_pages(pdf_full_path, raster_dpi)
        if not page_images:
            continue
        pages_rendered += len(page_images)

        per_pdf_count = 0
        for page_number, image_bytes in page_images:
            descriptions = _describe_page(image_bytes, pdf_name, page_number, llm_kwargs)
            if descriptions:
                descriptions_by_page.setdefault((pdf_name, page_number), []).extend(descriptions)
                per_pdf_count += len(descriptions)

        # Drop this PDF's PNG bytes before rendering the next one.
        del page_images
        logger.info("[FIGURES] %s: %d figure(s) extracted.", pdf_name, per_pdf_count)

    if pages_rendered == 0:
        logger.info("No pages to describe (all PDFs failed to rasterise).")
        return

    total = sum(len(v) for v in descriptions_by_page.values())
    logger.info("Total figures captured: %d across %d PDF(s).", total,
                len({k[0] for k in descriptions_by_page}))

    # -----------------------------------------------------------------------
    # Step 4 – Write figure descriptions to 02_visuals_kb.jsonl
    # -----------------------------------------------------------------------
    figure_rows: List[Dict] = []

    for (pdf_name, page_number), descriptions in descriptions_by_page.items():
        document_id = make_document_id(pdf_name)
        for fig_idx, description in enumerate(descriptions):
            figure_rows.append({
                "source": pdf_name,
                "page": page_number,
                "document_id": document_id,
                "figure_index": fig_idx,
                "figure_id": f"{document_id}:p{page_number}:f{fig_idx}",
                "description": description,
            })

    if figure_rows:
        figures_path = os.path.join(output_dir, "02_visuals_kb.jsonl")
        append_to_jsonl(figures_path, figure_rows)
        print(f"[FIGURES] Wrote {len(figure_rows)} figure record(s) to 02_visuals_kb.jsonl.")
    else:
        logger.info("No figures detected — 02_visuals_kb.jsonl not updated.")
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""
|
|
2
|
+
jsonl_writer.py — Atomic JSONL append helper and stable document-ID generator.
|
|
3
|
+
|
|
4
|
+
Consolidated from the two duplicate copies that existed in
|
|
5
|
+
utils/nougat_helpers.py and PDFAnalyser.py.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import hashlib
|
|
11
|
+
import json
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
from typing import Dict, List
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def make_document_id(source: str) -> str:
    """
    Return a stable 16-character hex ID for *source* (normally a PDF basename).

    The ID is the first 16 hex digits of the SHA-1 of the UTF-8 encoded name.
    It is stable across runs as long as the filename doesn't change, which
    lets downstream systems deduplicate without re-reading JSONL files. SHA-1
    is used purely as a fingerprint here, not for security.

    Non-``str`` inputs such as :class:`pathlib.Path` are converted with
    ``str()`` first. Lone surrogate code points (as produced when undecodable
    filename bytes are decoded with ``surrogateescape``) are encoded with
    ``surrogatepass``, so such names still get a hex ID. For every ordinary
    string the result is unchanged from plain UTF-8 hashing.

    Args:
        source: The name to fingerprint.

    Returns:
        A 16-character lowercase hex string. *source* is returned unchanged
        (with a warning) only if it cannot be converted to text at all.
    """
    try:
        encoded = str(source).encode("utf-8", "surrogatepass")
    except Exception as exc:
        logger.warning("Could not hash source %r: %s — using raw name as fallback.", source, exc)
        return source
    return hashlib.sha1(encoded).hexdigest()[:16]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _ends_with_newline(path: str) -> bool:
    """Return True if *path* does not exist, is empty, or its last byte is ``\\n``."""
    try:
        with open(path, "rb") as fh:
            fh.seek(0, os.SEEK_END)
            if fh.tell() == 0:
                return True
            fh.seek(-1, os.SEEK_END)
            return fh.read(1) == b"\n"
    except FileNotFoundError:
        return True


def append_to_jsonl(jsonl_file: str, new_data: List[Dict]) -> None:
    """
    Safely append *new_data* to a JSON Lines file.

    - Creates the file (and any missing parent directories) if needed.
    - Skips individual rows that cannot be serialised without aborting the
      entire write. This covers non-dicts, values ``json.dumps`` rejects, and
      text UTF-8 cannot encode, such as lone surrogates.
    - Never glues a new row onto existing content. If the file does not
      already end with ``\\n`` (e.g. an earlier write was interrupted), a
      newline is written first so every appended row starts on its own line.
    - Every row is serialised before the file is opened, and each is written
      as a complete ``\\n``-terminated JSON line.

    File-system and unexpected errors are logged, not raised.

    Args:
        jsonl_file: Absolute or relative path to the target ``.jsonl`` file.
        new_data: List of dicts to write (one per line).
    """
    if not isinstance(new_data, list):
        logger.warning(
            "append_to_jsonl: new_data must be a list, got %s — skipping.",
            type(new_data).__name__,
        )
        return

    # Serialise (and check encodability of) every row up front so a bad row
    # can never leave a half-written line behind.
    lines: List[str] = []
    for row in new_data:
        if not isinstance(row, dict):
            logger.warning("Skipping non-dict JSONL entry: %s", type(row).__name__)
            continue
        try:
            line = json.dumps(row, ensure_ascii=False) + "\n"
            # Raises UnicodeEncodeError (a ValueError) for lone surrogates.
            line.encode("utf-8")
        except (TypeError, ValueError) as exc:
            logger.warning("Failed to serialise row — skipping. Error: %s", exc)
            continue
        lines.append(line)

    try:
        parent = os.path.dirname(jsonl_file)
        if parent:
            os.makedirs(parent, exist_ok=True)

        needs_separator = bool(lines) and not _ends_with_newline(jsonl_file)
        with open(jsonl_file, "a", encoding="utf-8") as fh:
            if needs_separator:
                fh.write("\n")
            fh.writelines(lines)

    except OSError as exc:
        logger.error("File-system error writing %s: %s", jsonl_file, exc)
    except Exception as exc:
        logger.error("Unexpected error writing %s: %s", jsonl_file, exc)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def read_jsonl(jsonl_path: str) -> List[Dict]:
    """
    Read all valid JSON lines from *jsonl_path*.

    Each ``\\n``-separated line is decoded on its own. A line that is not valid
    UTF-8 or not valid JSON is skipped with a warning, and every other line is
    still returned. A single bad byte no longer aborts the rest of the file.
    Blank (whitespace-only) lines are ignored.

    Args:
        jsonl_path: Path to the ``.jsonl`` file.

    Returns:
        The decoded value of every valid line, in file order. A missing file
        yields an empty list (with a warning). Other I/O errors are logged and
        the rows read so far are returned.
    """
    rows: List[Dict] = []
    if not os.path.exists(jsonl_path):
        logger.warning("JSONL file not found: %s", jsonl_path)
        return rows

    try:
        # Binary mode: decode per line so undecodable bytes only cost that line.
        with open(jsonl_path, "rb") as fh:
            for line_num, raw_line in enumerate(fh, start=1):
                try:
                    line = raw_line.decode("utf-8").strip()
                    if line:
                        rows.append(json.loads(line))
                except (UnicodeDecodeError, json.JSONDecodeError) as exc:
                    logger.warning(
                        "Skipping corrupted JSONL line %d in %s: %s",
                        line_num, jsonl_path, exc,
                    )
    except Exception as exc:
        logger.error("Error reading %s: %s", jsonl_path, exc)

    return rows
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""
|
|
2
|
+
metadata_extractor.py — Extract document-level metadata (title, authors, DOI …)
|
|
3
|
+
from the front pages of a PDF using a vision LLM.
|
|
4
|
+
|
|
5
|
+
Extracted and cb-decoupled from utils/general_utilities.py.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
from typing import Any, Dict, List, Optional
|
|
13
|
+
|
|
14
|
+
import fitz # PyMuPDF
|
|
15
|
+
|
|
16
|
+
from visual_parser.prompts import METADATA_PROMPT_TEMPLATE
|
|
17
|
+
from visual_parser.vision_llm import call_vision_llm
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def extract_pdf_metadata(
    pdf_path: str,
    vision_provider: str,
    vision_api_key: str,
    vision_model: str,
    num_pages: int = 2,
    vision_detail: str = "auto",
    reasoning_effort: Optional[str] = "medium",
) -> Dict[str, Any]:
    """
    Rasterize the first *num_pages* of *pdf_path*, send them to the Vision LLM,
    and parse the JSON metadata response.

    Args:
        pdf_path: Absolute path to the PDF file.
        vision_provider: ``'gpt'`` or ``'gemini'``.
        vision_api_key: API key for the chosen provider.
        vision_model: Model name string.
        num_pages: Maximum number of front pages to send (default: 2). Fewer
            are sent when the PDF is shorter or a page fails to render, and
            the prompt states the number actually sent.
        vision_detail: Image detail level for GPT ('low', 'high', 'auto').
        reasoning_effort: Passed through unchanged to ``call_vision_llm``.

    Returns:
        The JSON object parsed from the model's reply, with the model's keys
        as-is; no keys are added here. The prompt presumably asks for title,
        authors, publication_date, report_number, doi and keywords; confirm
        against ``METADATA_PROMPT_TEMPLATE``.

    Raises:
        RuntimeError: if the PDF cannot be opened, no page renders, the model
            reply is not a string or contains no ``{...}`` span, or no JSON
            object can be parsed from it.
    """
    # 1) Rasterize front pages
    try:
        doc = fitz.open(pdf_path)
    except Exception as exc:
        raise RuntimeError(f"Failed to open PDF {pdf_path!r}: {exc}") from exc

    images: List[bytes] = []
    # The context manager closes the document even if something below raises.
    with doc:
        for i in range(min(num_pages, doc.page_count)):
            try:
                pix = doc.load_page(i).get_pixmap(dpi=200)
                images.append(pix.tobytes("png"))
            except Exception as exc:
                logger.warning("Skipping page %d of %s: %s", i, pdf_path, exc)

    if not images:
        raise RuntimeError(f"No pages rendered from {pdf_path!r}")

    # 2) Build prompt, telling the model how many page images it actually receives.
    prompt = METADATA_PROMPT_TEMPLATE.format(num_pages=len(images))

    # 3) Call Vision LLM
    raw = call_vision_llm(
        images=images,
        prompt=prompt,
        provider=vision_provider,
        api_key=vision_api_key,
        model=vision_model,
        detail=vision_detail,
        reasoning_effort=reasoning_effort,
    )

    # A missing/non-text reply must surface as the documented RuntimeError,
    # not as an AttributeError from str methods below.
    if not isinstance(raw, str):
        raise RuntimeError(f"No JSON found in vision LLM response:\n{raw}")

    # 4) Extract and parse JSON substring
    start = raw.find("{")
    end = raw.rfind("}")
    if start < 0 or end < 0 or end <= start:
        raise RuntimeError(f"No JSON found in vision LLM response:\n{raw}")

    candidate = raw[start: end + 1].strip().strip("```").strip()
    try:
        return json.loads(candidate)
    except json.JSONDecodeError as exc:
        # The first-"{"-to-last-"}" span also swallows any trailing prose that
        # contains braces. Fall back to decoding just the first complete
        # object that begins at the first "{".
        try:
            obj, _ = json.JSONDecoder().raw_decode(raw, start)
        except json.JSONDecodeError:
            raise RuntimeError(
                f"Failed to parse metadata JSON:\n{candidate}\nError: {exc}"
            ) from exc
        return obj
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""
|
|
2
|
+
nougat_engine.py — Nougat model initialisation, PDF rasterisation, and
|
|
3
|
+
the stopping-criteria classes from the original Nougat paper.
|
|
4
|
+
|
|
5
|
+
Extracted and cleaned from utils/nougat_helpers.py.
|
|
6
|
+
No chatbot, no LangChain, no Dash dependencies.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import io
|
|
12
|
+
import logging
|
|
13
|
+
from collections import defaultdict
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import List, Optional
|
|
16
|
+
|
|
17
|
+
import fitz # PyMuPDF
|
|
18
|
+
import torch
|
|
19
|
+
from PIL import Image
|
|
20
|
+
from transformers import (
|
|
21
|
+
AutoProcessor,
|
|
22
|
+
StoppingCriteria,
|
|
23
|
+
StoppingCriteriaList,
|
|
24
|
+
VisionEncoderDecoderModel,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# ---------------------------------------------------------------------------
|
|
31
|
+
# Model initialisation
|
|
32
|
+
# ---------------------------------------------------------------------------
|
|
33
|
+
|
|
34
|
+
def _normalize_nougat_processor(processor) -> None:
|
|
35
|
+
"""
|
|
36
|
+
Apply compatibility fixes for processor configs across transformers versions.
|
|
37
|
+
|
|
38
|
+
Some newer processor/image-processor builds reject ``None`` for boolean
|
|
39
|
+
fields that older Nougat configs may omit. Normalize those fields to safe
|
|
40
|
+
defaults after loading the processor.
|
|
41
|
+
"""
|
|
42
|
+
image_processor = getattr(processor, "image_processor", None)
|
|
43
|
+
if image_processor is None:
|
|
44
|
+
return
|
|
45
|
+
|
|
46
|
+
fixed_fields = []
|
|
47
|
+
for attr_name in dir(image_processor):
|
|
48
|
+
if not attr_name.startswith("do_"):
|
|
49
|
+
continue
|
|
50
|
+
try:
|
|
51
|
+
attr_value = getattr(image_processor, attr_name)
|
|
52
|
+
except Exception:
|
|
53
|
+
continue
|
|
54
|
+
if attr_value is None:
|
|
55
|
+
try:
|
|
56
|
+
setattr(image_processor, attr_name, False)
|
|
57
|
+
fixed_fields.append(attr_name)
|
|
58
|
+
except Exception:
|
|
59
|
+
continue
|
|
60
|
+
|
|
61
|
+
if fixed_fields:
|
|
62
|
+
logger.warning(
|
|
63
|
+
"[NOUGAT] Normalized image-processor boolean flags with None values: %s",
|
|
64
|
+
", ".join(sorted(fixed_fields)),
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
def NougatInitializer(model_name: str = "facebook/nougat-small"):
    """
    Load a Nougat processor and model, placing the model on CUDA when available.

    A Hugging Face token is read from ``HF_TOKEN``, falling back to
    ``HUGGING_FACE_HUB_TOKEN``.

    - When a token is found, ``huggingface_hub.login`` is attempted; any
      failure is logged as a warning and ignored. The token is also passed
      explicitly to both ``from_pretrained`` calls.
    - Without a token, an info message is logged and downloads proceed
      unauthenticated.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        ``(processor, model, device)`` where *device* is ``"cuda"`` or ``"cpu"``.
    """
    import os

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if not token:
        logger.info(
            "[NOUGAT] No HF_TOKEN found — downloads may be rate-limited. "
            "Add HF_TOKEN to your .env to silence this."
        )
    else:
        # NOTE(review): huggingface_hub.login may persist the token to the local
        # HF cache. Since the token is also passed explicitly to from_pretrained
        # below, confirm the login call is still wanted.
        try:
            import huggingface_hub

            huggingface_hub.login(token=token, add_to_git_credential=False)
            logger.info("[NOUGAT] Authenticated with HuggingFace Hub.")
        except Exception as exc:
            logger.warning("[NOUGAT] HF login attempt failed (non-fatal): %s", exc)

    print(f"[NOUGAT] Loading model: {model_name} …")
    processor = AutoProcessor.from_pretrained(model_name, token=token)
    _normalize_nougat_processor(processor)
    model = VisionEncoderDecoderModel.from_pretrained(model_name, token=token)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    print(f"[NOUGAT] Model loaded on {device}.")
    return processor, model, device
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# ---------------------------------------------------------------------------
|
|
105
|
+
# PDF rasterisation
|
|
106
|
+
# ---------------------------------------------------------------------------
|
|
107
|
+
|
|
108
|
+
def RasterizePaper(
    pdf: Path | str,
    outpath: Optional[Path] = None,
    dpi: int = 96,
    return_pil: bool = False,
    pages: Optional[range] = None,
) -> Optional[List[io.BytesIO]]:
    """
    Rasterize pages of *pdf* to PNG.

    Args:
        pdf: Path to the PDF file, or an already-open document object that
            supports ``len()`` and page indexing. A document passed in this
            way is left open for the caller; one opened here from a path is
            always closed before returning.
        outpath: Directory to write ``01.png``, ``02.png`` … files into.
            A ``str`` is accepted as well as a :class:`~pathlib.Path`.
            When *None*, ``return_pil`` is forced to True.
        dpi: Rendering resolution (96 dpi for Nougat, 200 for figures).
        return_pil: Return a list of :class:`io.BytesIO` objects instead of
            writing files.
        pages: Subset of 0-based page indices to process. Defaults to all pages.

    Returns:
        List of :class:`io.BytesIO` objects when ``return_pil=True``,
        otherwise *None* (files written to *outpath*).

    Errors are logged rather than raised. In that case the buffers rendered
    before the failure are returned, or the files already written remain.
    """
    if outpath is None:
        return_pil = True

    png_buffers: List[io.BytesIO] = []
    opened_here = isinstance(pdf, (str, Path))
    doc = None
    try:
        doc = fitz.open(pdf) if opened_here else pdf
        target_dir = None if return_pil else Path(outpath)
        if pages is None:
            pages = range(len(doc))
        for i in pages:
            page_bytes: bytes = doc[i].get_pixmap(dpi=dpi).pil_tobytes(format="PNG")
            if return_pil:
                png_buffers.append(io.BytesIO(page_bytes))
            else:
                (target_dir / ("%02d.png" % (i + 1))).write_bytes(page_bytes)
    except Exception as exc:
        logger.error("Error rasterizing PDF %s: %s", pdf, exc)
    finally:
        # Close only documents opened here; caller-owned ones stay open.
        if opened_here and doc is not None:
            doc.close()

    return png_buffers if return_pil else None
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# ---------------------------------------------------------------------------
|
|
153
|
+
# Nougat stopping criteria (from the original Nougat repository)
|
|
154
|
+
# ---------------------------------------------------------------------------
|
|
155
|
+
|
|
156
|
+
class RunningVarTorch:
    """
    Keep the last *L* 1-D tensors as columns and report their per-row variance.

    ``values`` is ``None`` until the first push. After that it has one column
    per pushed tensor, capped at *L* columns, with the oldest column dropped
    first. With ``norm=True`` the variance is divided by the current number of
    columns.
    """

    def __init__(self, L: int = 15, norm: bool = False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor) -> None:
        """Append 1-D tensor *x* as a new column, evicting the oldest when full."""
        assert x.dim() == 1
        column = x[:, None]
        if self.values is None:
            self.values = column
            return
        is_full = not (self.values.shape[1] < self.L)
        kept = self.values[:, 1:] if is_full else self.values
        self.values = torch.cat((kept, column), 1)

    def variance(self):
        """Return the per-row variance across stored columns, or None if empty."""
        if self.values is None:
            return None
        per_row = torch.var(self.values, 1)
        if not self.norm:
            return per_row
        return per_row / self.values.shape[1]
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class StoppingCriteriaScores(StoppingCriteria):
    """
    Stops generation when the variance of the score distribution stabilises —
    as recommended by the Nougat authors to avoid repetition loops.

    On each call, the maximum score of the latest step is recorded per batch
    row.

    - ``vars`` keeps a normalised running variance of those maxima over a
      short window (``RunningVarTorch`` default length).
    - ``varvars`` keeps the variance of that variance over *window_size* steps.
    - After *window_size* calls, a row whose ``varvars`` falls below
      *threshold* is marked stopped.
    - Generation ends when every tracked row is marked.

    The instance is stateful (step counter, per-row flags). Presumably a fresh
    instance is needed per ``generate`` call; confirm at call sites.
    """

    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        """
        Update the running statistics with the latest step and decide whether to stop.

        Args:
            input_ids: Tokens generated so far (unused; part of the
                ``StoppingCriteria`` call signature).
            scores: Indexed with ``[-1]`` to get the latest step's scores.
                Presumably the tuple of per-step ``(batch, vocab)`` score
                tensors that ``generate`` accumulates when
                ``output_scores=True``; confirm at the call site.
            **kwargs: Accepted and ignored, for compatibility with the
                ``StoppingCriteria.__call__(input_ids, scores, **kwargs)``
                base signature.

        Returns:
            True once every tracked batch row has been marked stopped.
        """
        last_scores = scores[-1]
        # Per-row max score of the latest step, as a 1-D float CPU tensor.
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        # Do not decide anything until a full window of statistics exists.
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    # NOTE(review): this comparison reads as inverted (True as
                    # soon as stop_inds exceeds the current step). It is kept
                    # verbatim to match the upstream Nougat code this section
                    # is ported from.
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0
|