ocrcontext 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocrcontext/__init__.py +49 -0
- ocrcontext/analyzer.py +198 -0
- ocrcontext/config.py +49 -0
- ocrcontext/engines/__init__.py +6 -0
- ocrcontext/engines/base.py +45 -0
- ocrcontext/engines/handwriting.py +103 -0
- ocrcontext/engines/paddle.py +264 -0
- ocrcontext/engines/pdf_text.py +126 -0
- ocrcontext/engines/registry.py +67 -0
- ocrcontext/engines/trocr.py +191 -0
- ocrcontext/engines/vision.py +538 -0
- ocrcontext/exceptions.py +45 -0
- ocrcontext/llm/__init__.py +10 -0
- ocrcontext/llm/drift.py +58 -0
- ocrcontext/llm/extractor.py +63 -0
- ocrcontext/llm/formatting.py +39 -0
- ocrcontext/llm/literal_preserve.py +164 -0
- ocrcontext/llm/prompts.py +157 -0
- ocrcontext/llm/refiner.py +114 -0
- ocrcontext/llm/schemas.py +99 -0
- ocrcontext/pipeline.py +162 -0
- ocrcontext/preprocessing/__init__.py +5 -0
- ocrcontext/preprocessing/image.py +177 -0
- ocrcontext/py.typed +0 -0
- ocrcontext/quality.py +76 -0
- ocrcontext/schemas.py +8 -0
- ocrcontext/types.py +55 -0
- ocrcontext/utils/__init__.py +1 -0
- ocrcontext/utils/files.py +172 -0
- ocrcontext/utils/lang.py +77 -0
- ocrcontext-0.1.0.dist-info/METADATA +207 -0
- ocrcontext-0.1.0.dist-info/RECORD +34 -0
- ocrcontext-0.1.0.dist-info/WHEEL +4 -0
- ocrcontext-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""PaddleOCR engine for printed text and scanned documents.
|
|
2
|
+
|
|
3
|
+
Ported from ocr-service/modal_app.py::OCRService — the lazy per-language model
|
|
4
|
+
cache, multi-language *coverage-first* candidate selection, and the line-band
|
|
5
|
+
recovery fallback are preserved exactly. The Modal/GPU plumbing is removed.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
import sys
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
from ..exceptions import EngineError, MissingDependencyError
|
|
15
|
+
from ..preprocessing.image import preprocess_image_for_ocr, split_image_into_line_bands
|
|
16
|
+
from ..utils.files import ascii_safe_dir, cleanup_paths, is_ascii
|
|
17
|
+
from ..utils.lang import candidate_langs
|
|
18
|
+
from .base import OcrEngine, PageOcr
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _ensure_ascii_model_cache() -> None:
    """Redirect PaddleX / HuggingFace model caches to an ASCII-only path on Windows.

    PaddlePaddle's C++ model loader cannot open files whose path contains
    non-ASCII characters (for example a home directory under a non-ASCII
    Windows user name) and fails with an "attempting to parse an empty input"
    JSON error. Pointing the caches at the 8.3 short form of the home directory
    reaches the very same files through an ASCII-only path. Cache variables the
    user has already set are left untouched.
    """
    # Nothing to do off Windows.
    if sys.platform != "win32":
        return
    home_dir = str(Path.home())
    # An ASCII home directory already works with the C++ loader.
    if is_ascii(home_dir):
        return
    short_home = ascii_safe_dir(home_dir)
    if not is_ascii(short_home):
        # No ASCII short path available; there is nothing safe to redirect to.
        return
    cache_defaults = {
        "PADDLE_PDX_CACHE_HOME": os.path.join(short_home, ".paddlex"),
        "HF_HOME": os.path.join(short_home, ".cache", "huggingface"),
    }
    for env_name, cache_dir in cache_defaults.items():
        os.environ.setdefault(env_name, cache_dir)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _ensure_paddle_runtime_flags() -> None:
    """Turn off oneDNN (MKLDNN) acceleration for CPU inference.

    Under PaddlePaddle 3.x the new-IR (PIR) executor reaches an unimplemented
    oneDNN op for some PP-OCR models on CPU ("ConvertPirAttribute2RuntimeAttribute
    not support"). Disabling oneDNN routes inference through the standard
    kernels. The setting goes through the environment flag so it applies no
    matter which constructor keywords the installed build accepts, and a value
    the user already exported always wins.
    """
    flag_name = "FLAGS_use_mkldnn"
    if flag_name not in os.environ:
        os.environ[flag_name] = "0"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _score_as_float(value) -> "float | None":
    """Return *value* as a plain ``float`` (unwrapping e.g. NumPy scalars), or None."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _extract_from_result(result):
    """Normalize PaddleOCR / PaddleX result objects into ``(text, scores)``.

    Only the first page of ``result`` is inspected. Two shapes are recognized:

    * PaddleOCR 3.x / PaddleX: a mapping (or a list whose first item is a
      mapping) exposing ``rec_texts`` and ``rec_scores``.
    * PaddleOCR 2.x: a list whose first item is a list of
      ``[box, (text, score)]`` entries.

    Returns:
        ``text``: every recognized line, each terminated by ``"\\n"``.
        ``scores``: the matching confidence scores as plain floats. Scores
        that cannot be converted to float are skipped.
        Unrecognized or empty input yields ``("", [])``.
    """
    if not result:
        return "", []
    # ``result`` is non-empty here, so indexing a list is safe.
    first_page = result[0] if isinstance(result, list) else result

    lines: list[str] = []
    scores: list[float] = []

    if hasattr(first_page, "keys") and "rec_texts" in first_page:
        # PaddleOCR 3.x / PaddleX: parallel ``rec_texts`` / ``rec_scores``.
        # A key that is present but None is treated as empty. The sequences
        # are never truth-tested because they may be NumPy arrays.
        texts = first_page.get("rec_texts")
        raw_scores = first_page.get("rec_scores")
        if texts is None:
            texts = []
        if raw_scores is None:
            raw_scores = []
        for index, text in enumerate(texts):
            lines.append(str(text))
            if index < len(raw_scores):
                score = _score_as_float(raw_scores[index])
                if score is not None:
                    scores.append(score)
    elif isinstance(first_page, list):
        # PaddleOCR 2.x: entries shaped like ``[box, (text, score)]``.
        for entry in first_page:
            try:
                if not (isinstance(entry, (list, tuple)) and len(entry) >= 2):
                    continue
                recognition = entry[1]
                if not (isinstance(recognition, (list, tuple)) and len(recognition) >= 2):
                    continue
                # Compute both values before appending, so a failure cannot
                # leave the text and score lists misaligned.
                text = str(recognition[0])
                score = _score_as_float(recognition[1])
            except Exception:  # noqa: BLE001 - best effort: skip malformed entries
                continue
            lines.append(text)
            if score is not None:
                scores.append(score)

    return "".join(f"{line}\n" for line in lines), scores
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class PaddleEngine(OcrEngine):
    """Lazy, per-language singleton-style PaddleOCR wrapper.

    A single instance caches one PaddleOCR model per language code, so models are
    loaded into memory at most once (resource-efficiency requirement). Model
    construction is serialized with a lock: the engine is designed to be shared
    between threads (see ``EngineRegistry``), and without the lock concurrent
    first calls for the same language would each load their own copy.
    """

    text_source = "ocr"

    def __init__(self) -> None:
        import threading

        # Requested PaddleOCR language code -> loaded PaddleOCR instance.
        self._ocr_by_lang: dict[str, object] = {}
        # Guards model construction (see class docstring).
        self._load_lock = threading.Lock()

    def _get_ocr(self, paddle_lang: str):
        """Return the cached PaddleOCR model for a language, loading it on first use.

        If the requested language cannot be initialized, the English model is
        used instead and cached under the *requested* code, so the failure is
        not retried on later calls.

        Raises:
            MissingDependencyError: if ``paddleocr`` is not installed.
            EngineError: if neither the language nor English can be initialized.
        """
        if paddle_lang in self._ocr_by_lang:
            return self._ocr_by_lang[paddle_lang]
        with self._load_lock:
            # Re-check: another thread may have finished this load while we waited.
            if paddle_lang not in self._ocr_by_lang:
                self._ocr_by_lang[paddle_lang] = self._load_ocr(paddle_lang)
            return self._ocr_by_lang[paddle_lang]

    def _load_ocr(self, paddle_lang: str):
        """Construct a PaddleOCR model for ``paddle_lang``, with English as fallback."""
        _ensure_ascii_model_cache()
        _ensure_paddle_runtime_flags()
        try:
            from paddleocr import PaddleOCR
        except ImportError as exc:  # pragma: no cover - exercised via install matrix
            raise MissingDependencyError("paddleocr", "paddle") from exc

        import logging

        logging.getLogger("ppocr").setLevel(logging.ERROR)
        ocr, errors = self._try_init(PaddleOCR, paddle_lang)
        if ocr is None and paddle_lang != "en":
            # Reuse an already-loaded English model instead of loading a second copy.
            ocr = self._ocr_by_lang.get("en")
            if ocr is None:
                ocr, en_errors = self._try_init(PaddleOCR, "en")
                errors.extend(en_errors)
        if ocr is None:
            detail = "; ".join(errors[-3:]) if errors else "no profiles attempted"
            raise EngineError(
                f"PaddleOCR could not be initialized for lang={paddle_lang!r}. "
                f"Last errors: {detail}"
            )
        return ocr

    @staticmethod
    def _try_init(PaddleOCR, lang: str):
        """Try several constructor signatures (PaddleOCR 3.x and legacy 2.x).

        Returns (engine_or_None, [error_strings]). The real exceptions are kept so
        a total failure can be diagnosed instead of silently swallowed.
        """
        # PaddleOCR 3.x: disable the doc-orientation / unwarping sub-models (unneeded
        # for plain OCR) and oneDNN (CPU PIR incompatibility). Each kwarg is tried and
        # gracefully dropped if a given build rejects it.
        base_3x = {
            "lang": lang,
            "use_doc_orientation_classify": False,
            "use_doc_unwarping": False,
            "use_textline_orientation": False,
        }
        profiles = [
            {**base_3x, "enable_mkldnn": False},
            base_3x,
            {"lang": lang, "enable_mkldnn": False},
            {"lang": lang},
            # Legacy 2.x signature (use_angle_cls / show_log removed in 3.x).
            {"use_angle_cls": True, "use_textline_orientation": True, "lang": lang,
             "show_log": False},
        ]
        errors: list[str] = []
        for kwargs in profiles:
            try:
                return PaddleOCR(**kwargs), errors
            except Exception as exc:  # noqa: BLE001 - we record and try the next profile
                errors.append(f"{type(exc).__name__}: {exc}")
        return None, errors

    @staticmethod
    def _run_ocr(ocr_engine, path: str):
        """Run recognition across PaddleOCR 3.x (``.predict``) and 2.x (``.ocr``).

        ``predict`` is preferred when present. If it raises, the legacy ``ocr``
        method is tried. Should that also fail, its exception propagates with the
        ``predict`` failure attached as ``__cause__``, so the first error is not lost.
        """
        predict = getattr(ocr_engine, "predict", None)
        predict_error: Exception | None = None
        if callable(predict):
            try:
                return predict(path)
            except Exception as exc:  # noqa: BLE001 - fall back to the legacy API
                predict_error = exc
        try:
            return ocr_engine.ocr(path)
        except Exception as legacy_error:
            if predict_error is None:
                raise
            raise legacy_error from predict_error

    @staticmethod
    def _nonempty_line_count(text: str) -> int:
        """Return the number of lines in *text* that contain non-whitespace characters."""
        return sum(1 for ln in text.splitlines() if ln.strip())

    def recognize(
        self,
        img_path: str,
        *,
        lang: str = "en",
        min_lines: int = 3,
        handwriting: bool = False,
    ) -> PageOcr:
        """OCR one page image, trying every candidate language and keeping the best.

        Selection is coverage-first and ignores confidence:
        1) more non-empty lines wins;
        2) on a tie, the longer stripped text wins;
        3) on an exact tie, the earliest candidate is kept.

        If the winner has fewer than ``min_lines`` non-empty lines, a line-band
        fallback re-OCRs horizontal strips of the image.

        Args:
            img_path: Path to the page image.
            lang: Requested language; expanded by ``candidate_langs``.
            min_lines: Line count below which the line-band fallback runs.
            handwriting: Passed through to ``preprocess_image_for_ocr``.

        Returns:
            ``PageOcr`` with the stripped text and the collected scores.
        """
        langs = candidate_langs(lang)
        preprocessed_paths: list[str] = []

        try:
            ocr_img_path = preprocess_image_for_ocr(img_path, handwriting=handwriting)
            if ocr_img_path != img_path:
                # Preprocessing wrote a new file; delete it once we are done.
                preprocessed_paths.append(ocr_img_path)

            best_text = ""
            best_scores: list[float] = []
            best_line_count = 0

            for lang_code in langs:
                result = self._run_ocr(self._get_ocr(lang_code), ocr_img_path)
                candidate_text, candidate_scores = _extract_from_result(result)
                candidate_line_count = self._nonempty_line_count(candidate_text)

                more_lines = candidate_line_count > best_line_count
                same_lines_but_longer = (
                    candidate_line_count == best_line_count
                    and len(candidate_text.strip()) > len(best_text.strip())
                )
                if more_lines or same_lines_but_longer:
                    best_line_count = candidate_line_count
                    best_text = candidate_text
                    best_scores = candidate_scores

            text = best_text
            scores = list(best_scores)

            # best_line_count already describes best_text: both start empty/0 and
            # are only ever updated together.
            if best_line_count < min_lines:
                text, scores = self._line_band_fallback(
                    ocr_img_path, langs, base_text=best_text, base_scores=scores,
                    best_line_count=best_line_count,
                )

            return PageOcr(text=text.strip(), scores=scores)
        finally:
            cleanup_paths(preprocessed_paths)

    def _line_band_fallback(
        self,
        ocr_img_path: str,
        langs: list[str],
        *,
        base_text: str,
        base_scores: list[float],
        best_line_count: int,
    ) -> tuple[str, list[float]]:
        """Re-OCR horizontal line bands when the full-image pass found too few lines.

        For each band, processed top to bottom, the candidate language producing
        the longest stripped text is kept and flattened onto a single line. The
        recovered lines are appended after ``base_text``, but only when there are
        more of them than ``best_line_count``. Otherwise the base result is
        returned unchanged. Band image files are always deleted.
        """
        band_paths = split_image_into_line_bands(ocr_img_path)
        if not band_paths:
            return base_text, base_scores

        recovered_lines: list[str] = []
        recovered_scores: list[float] = []
        created_paths = [path for _, path in band_paths]
        try:
            for _, band_path in sorted(band_paths, key=lambda band: band[0]):
                band_best_text = ""
                band_best_len = 0
                band_best_scores: list[float] = []
                for lang_code in langs:
                    result = self._run_ocr(self._get_ocr(lang_code), band_path)
                    txt, sc = _extract_from_result(result)
                    txt_len = len(txt.strip())
                    if txt_len > band_best_len:
                        band_best_len = txt_len
                        band_best_text = txt
                        band_best_scores = sc
                line = " ".join(
                    part.strip() for part in band_best_text.splitlines() if part.strip()
                ).strip()
                if line:
                    recovered_lines.append(line)
                    recovered_scores.extend(band_best_scores)
        finally:
            cleanup_paths(created_paths)

        if len(recovered_lines) <= best_line_count:
            return base_text, base_scores

        # rstrip() removes any trailing newline, so a non-empty prefix always
        # needs exactly one separator before the recovered lines.
        prefix = base_text.rstrip()
        text = (prefix + "\n" if prefix else "") + "\n".join(recovered_lines) + "\n"
        return text, base_scores + recovered_scores
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""Digital PDF text-layer extraction (no GPU / no OCR).
|
|
2
|
+
|
|
3
|
+
Ported verbatim from ocr-service/modal_app.py. Used to skip OCR entirely when a
|
|
4
|
+
PDF already carries an accurate text layer.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
|
|
11
|
+
# PowerPoint / Google Slides PDFs often expose internal image names in the text layer.
_PDF_IMAGE_ARTIFACT_RE = re.compile(
    r"^[\w.\-]{1,120}\.(?:png|jpe?g|gif|webp|bmp|tiff?|svg)$",
    flags=re.IGNORECASE,
)
# Exact (lower-cased) image names that are always treated as artifacts.
_PDF_KNOWN_ARTIFACTS = frozenset(
    ("preencoded.png", "image.png", "image1.png", "image2.png")
)


def is_pdf_text_artifact(line: str) -> bool:
    """Return True when *line* is only an embedded image filename leaked by the PDF.

    Blank input is never an artifact. Known names match case-insensitively.
    Anything containing a space or a path separator is treated as real text.
    Otherwise the line must look like a bare image filename.
    """
    candidate = (line or "").strip()
    if not candidate:
        return False
    if candidate.lower() in _PDF_KNOWN_ARTIFACTS:
        return True
    if any(separator in candidate for separator in (" ", "/", "\\")):
        return False
    return _PDF_IMAGE_ARTIFACT_RE.match(candidate) is not None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _collapse_blank_runs(lines: list[str]) -> list[str]:
    """Collapse runs of blank lines into a single ``""``, keeping other lines as-is."""
    compact: list[str] = []
    for ln in lines:
        if ln.strip():
            compact.append(ln)
        elif not compact or compact[-1] != "":
            compact.append("")
    return compact


def _page_text_from_blocks(blocks) -> str:
    """Render one page's PyMuPDF ``get_text("blocks")`` output as layout-preserving text."""
    # block tuple: (x0, y0, x1, y1, text, block_no, block_type) - block_type 0=text, 1=image
    text_blocks = [
        b
        for b in blocks or ()
        if len(b) >= 5
        and (len(b) < 7 or b[6] == 0)
        and isinstance(b[4], str)
        and b[4].strip()
    ]
    if not text_blocks:
        return ""
    # Reading order: top-to-bottom, then left-to-right.
    text_blocks.sort(key=lambda b: (round(float(b[1]), 1), round(float(b[0]), 1)))

    page_lines: list[str] = []
    prev_bottom = None
    for block in text_blocks:
        y0, y1 = float(block[1]), float(block[3])
        block_text = block[4].replace("\r\n", "\n").replace("\r", "\n").strip()
        if not block_text:
            continue

        # Insert a paragraph gap when blocks are separated by more than 8 units
        # vertically (never two blank lines in a row).
        if prev_bottom is not None and (y0 - prev_bottom) > 8:
            if page_lines and page_lines[-1] != "":
                page_lines.append("")

        page_lines.extend(
            ln.rstrip()
            for ln in block_text.split("\n")
            if ln.strip() and not is_pdf_text_artifact(ln)
        )
        prev_bottom = y1

    return "\n".join(_collapse_blank_runs(page_lines)).strip()


def extract_pdf_text_preserve_layout(file_bytes: bytes) -> tuple[str, int]:
    """Extract text from a digital PDF's text layer, preserving line order/layout.

    Each page is rendered from PyMuPDF text blocks sorted top-to-bottom, then
    left-to-right. A blank line is inserted where blocks are vertically
    separated, and leaked image filenames are dropped. Every page after the
    first is preceded by a ``--- Page N ---`` separator.

    Args:
        file_bytes: Raw PDF bytes.

    Returns:
        ``(full_text, page_count)``.
    """
    import fitz

    # The context manager closes the document even if extraction of some page
    # raises; a trailing explicit close() would be skipped in that case.
    with fitz.open(stream=file_bytes, filetype="pdf") as pdf_document:
        page_count = len(pdf_document)
        pages_output = [
            # Use block-level extraction to preserve paragraph breaks and reading order.
            _page_text_from_blocks(page.get_text("blocks"))
            for page in pdf_document
        ]

    full_text = "".join(
        (f"\n\n--- Page {idx + 1} ---\n\n" if idx > 0 else "") + page_text
        for idx, page_text in enumerate(pages_output)
    )
    return full_text, page_count
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def has_sufficient_pdf_text(
    text: str,
    *,
    min_chars: int = 80,
    min_alnum_ratio: float = 0.25,
) -> bool:
    """True when the PDF text layer is rich enough to use instead of OCR.

    The stripped text must be at least *min_chars* characters long. At least
    *min_alnum_ratio* of those characters must be alphanumeric, which rejects
    layers made up mostly of symbols or whitespace.

    Args:
        text: Extracted text layer; ``None`` or blank is never sufficient.
        min_chars: Minimum length of the stripped text (default 80).
        min_alnum_ratio: Minimum fraction of alphanumeric characters (default 0.25).
    """
    stripped = (text or "").strip()
    # The emptiness check also keeps the ratio below safe when min_chars <= 0.
    if not stripped or len(stripped) < min_chars:
        return False

    alnum_count = sum(ch.isalnum() for ch in stripped)
    return alnum_count / len(stripped) >= min_alnum_ratio
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Singleton registry for heavy OCR engines.
|
|
2
|
+
|
|
3
|
+
PaddleOCR and TrOCR models are expensive to load. The registry guarantees each
|
|
4
|
+
engine (and therefore each model) is instantiated at most once per process,
|
|
5
|
+
satisfying the resource-efficiency requirement. Loading is lazy: an engine is
|
|
6
|
+
only created the first time it is requested.
|
|
7
|
+
|
|
8
|
+
Thread-safe so the same singleton is shared across threads (e.g. a web worker
|
|
9
|
+
pool that wraps this library).
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import threading
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from .handwriting import HandwritingEngine
|
|
19
|
+
from .paddle import PaddleEngine
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class EngineRegistry:
    """Process-wide lazy cache of OCR engines.

    A default shared instance is exposed via :meth:`shared`, but callers may also
    create isolated registries (useful for tests).
    """

    _shared: "EngineRegistry | None" = None
    _shared_lock = threading.Lock()

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._paddle: "PaddleEngine | None" = None
        self._handwriting: "HandwritingEngine | None" = None

    @classmethod
    def shared(cls) -> "EngineRegistry":
        """Return the process-wide registry, creating it on first use."""
        instance = cls._shared
        if instance is None:
            with cls._shared_lock:
                instance = cls._shared
                if instance is None:
                    instance = cls()
                    cls._shared = instance
        return instance

    def _get_or_create(self, attr: str, factory):
        """Return the engine stored in ``attr``, building it with ``factory`` once.

        The engine is held in a local variable and returned from there. Because
        the attribute is never re-read after the check, a concurrent
        :meth:`reset` cannot turn the result into ``None``.
        """
        engine = getattr(self, attr)
        if engine is None:
            with self._lock:
                engine = getattr(self, attr)
                if engine is None:
                    engine = factory()
                    setattr(self, attr, engine)
        return engine

    def paddle(self) -> "PaddleEngine":
        """Return the cached PaddleOCR engine, importing and creating it lazily."""

        def _create():
            from .paddle import PaddleEngine

            return PaddleEngine()

        return self._get_or_create("_paddle", _create)

    def handwriting(self) -> "HandwritingEngine":
        """Return the cached handwriting engine, importing and creating it lazily."""

        def _create():
            from .handwriting import HandwritingEngine

            return HandwritingEngine()

        return self._get_or_create("_handwriting", _create)

    def reset(self) -> None:
        """Drop cached engines (frees model memory once nothing else references them).

        Mainly for tests.
        """
        with self._lock:
            self._paddle = None
            self._handwriting = None
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Microsoft TrOCR handwriting engine (line-by-line).
|
|
2
|
+
|
|
3
|
+
Ported verbatim from ocr-service/handwriting_ocr.py. Used as the fallback when
|
|
4
|
+
Google Vision is unavailable or returns nothing. ``transformers``/``torch`` are
|
|
5
|
+
imported lazily (install the ``trocr`` extra).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from ..exceptions import MissingDependencyError
|
|
11
|
+
from ..utils.files import cleanup_paths, new_temp_path
|
|
12
|
+
|
|
13
|
+
# Checkpoint id handed to ``from_pretrained`` when no model id is supplied.
TROCR_MODEL_ID: str = "microsoft/trocr-base-handwritten"
# Line crops shorter than this many pixels are skipped.
MIN_BAND_HEIGHT: int = 12
# ``max_new_tokens`` budget for each recognized line.
MAX_NEW_TOKENS: int = 128
# Images are rescaled to this height before recognition.
TARGET_HEIGHT: int = 384
# Images wider than this are shrunk to this width.
MAX_WIDTH: int = 1280
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def split_image_into_line_bands(img_path: str) -> list[tuple[int, str]]:
    """Split a page image into horizontal line crops for TrOCR (one line per image).

    The procedure is:
    1. Rows are classified as active via adaptive thresholding and a horizontal
       projection.
    2. Runs of at least 10 active rows become segments.
    3. Segments at most 12 rows apart are merged.
    4. Each merged band is padded by 12 rows, upscaled 2.5x and written to a
       temporary PNG.

    Returns:
        ``(y0, band_path)`` pairs, where ``y0`` is the top row of the padded
        band in the source image. The caller is responsible for deleting the
        returned files. Returns ``[]`` when the image cannot be read, is shorter
        or narrower than 80 px, or has no text-like rows. If writing a band
        fails with an exception, the bands already written are deleted before
        re-raising.
    """
    import cv2
    import numpy as np

    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return []

    h, w = img.shape[:2]
    if h < 80 or w < 80:
        return []

    blur = cv2.GaussianBlur(img, (3, 3), 0)
    bw = cv2.adaptiveThreshold(
        blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 31, 12
    )

    # A row is active when more than max(8, 2% of width) pixels are ink.
    row_sum = np.sum(bw > 0, axis=1)
    threshold = max(8, int(0.02 * w))
    active_rows = row_sum > threshold

    segments: list[tuple[int, int]] = []
    start = None
    for i, active in enumerate(active_rows):
        if active and start is None:
            start = i
        elif not active and start is not None:
            if i - start >= 10:
                segments.append((start, i))
            start = None
    if start is not None and (len(active_rows) - start) >= 10:
        segments.append((start, len(active_rows)))

    if not segments:
        return []

    merged: list[list[int]] = []
    for s, e in segments:
        if merged and s - merged[-1][1] <= 12:
            merged[-1][1] = e
        else:
            merged.append([s, e])

    pad = 12
    bands: list[tuple[int, str]] = []
    try:
        for s, e in merged:
            y0 = max(0, s - pad)
            y1 = min(h, e + pad)
            crop = img[y0:y1, :]
            if crop.shape[0] < MIN_BAND_HEIGHT:
                continue
            upscaled = cv2.resize(crop, None, fx=2.5, fy=2.5, interpolation=cv2.INTER_CUBIC)
            band_path = new_temp_path("png")
            # cv2.imwrite reports failure by returning False rather than raising;
            # don't hand a missing/empty file to the recognizer.
            # NOTE(review): assumes cleanup_paths tolerates a path that was never written.
            if not cv2.imwrite(band_path, upscaled):
                cleanup_paths([band_path])
                continue
            bands.append((y0, band_path))
    except BaseException:
        # Don't leak the crops already written when a later band fails.
        cleanup_paths([path for _, path in bands])
        raise

    return bands
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def prepare_image_for_trocr(image):
    """Resize to TrOCR-friendly dimensions (avoids ViT tensor errors on tiny/huge crops).

    The steps are applied in order:
    1. Convert the image to RGB.
    2. If either side is below 32 px, upscale so both sides are at least 32 px.
    3. Rescale to ``TARGET_HEIGHT``, keeping the aspect ratio (width at least 32).
    4. If wider than ``MAX_WIDTH``, shrink to that width.

    Images with a zero-sized side are returned after only the RGB conversion.
    """
    from PIL import Image

    image = image.convert("RGB")
    w, h = image.size
    if h < 1 or w < 1:
        return image

    if h < 32 or w < 32:
        scale = max(32 / w, 32 / h)
        w, h = max(32, int(w * scale)), max(32, int(h * scale))
        image = image.resize((w, h), Image.Resampling.LANCZOS)

    if h != TARGET_HEIGHT:
        new_w = max(32, int(w * (TARGET_HEIGHT / h)))
        image = image.resize((new_w, TARGET_HEIGHT), Image.Resampling.LANCZOS)
        w, h = image.size

    if w > MAX_WIDTH:
        # For extremely wide strips int(h * MAX_WIDTH / w) truncates to 0, and PIL
        # refuses zero-sized resizes, so keep at least one row.
        new_h = max(1, int(h * MAX_WIDTH / w))
        image = image.resize((MAX_WIDTH, new_h), Image.Resampling.LANCZOS)

    return image
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class TrOCRHandwritingEngine:
    """Line-level handwriting recognizer backed by a TrOCR vision-encoder-decoder model.

    The processor and model are only created by :meth:`load`. The recognition
    methods raise ``RuntimeError`` until it has been called.
    """

    def __init__(self, model_id: str = TROCR_MODEL_ID) -> None:
        self.model_id = model_id
        self._processor = None
        self._model = None
        self._device = None

    def load(self) -> None:
        """Load processor and model via ``from_pretrained``.

        The model is placed on CUDA when available, otherwise on CPU.

        Raises:
            MissingDependencyError: if ``torch`` or ``transformers`` cannot be imported.
        """
        try:
            import torch
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        except ImportError as exc:  # pragma: no cover - exercised via install matrix
            raise MissingDependencyError("transformers", "trocr") from exc

        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._processor = TrOCRProcessor.from_pretrained(self.model_id)
        self._model = VisionEncoderDecoderModel.from_pretrained(self.model_id)
        self._model.to(self._device)
        self._model.eval()

    def warmup_inference(self) -> None:
        """Run one throwaway recognition on a blank 384x96 image.

        This is a no-op if the model is not loaded. Failures are deliberately
        ignored because warmup is best-effort.
        """
        if self._processor is None or self._model is None:
            return
        from PIL import Image

        dummy = Image.new("RGB", (384, 96), color=(255, 255, 255))
        try:
            _ = self.recognize_pil(dummy)
        except Exception:
            pass

    def recognize_line_image_path(self, path: str) -> str:
        """Recognize the single text line in the image file at *path*."""
        from PIL import Image

        # Close the file handle deterministically instead of waiting for garbage
        # collection. Callers delete band files after recognition, and deleting
        # an open file fails on Windows.
        with Image.open(path) as opened:
            image = opened.convert("RGB")
        return self.recognize_pil(image)

    def recognize_pil(self, image) -> str:
        """Recognize one text line in a PIL image and return the stripped text.

        Raises:
            RuntimeError: if :meth:`load` has not been called.
        """
        # Check the load state before importing torch, so callers get the
        # accurate error even when torch is unavailable.
        if self._processor is None or self._model is None:
            raise RuntimeError("TrOCRHandwritingEngine.load() was not called")

        import torch

        image = prepare_image_for_trocr(image)
        # Positional call matches HF docs; avoids kwarg edge cases in older processors.
        pixel_values = self._processor(image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self._device)

        with torch.no_grad():
            generated_ids = self._model.generate(
                pixel_values,
                max_new_tokens=MAX_NEW_TOKENS,
                num_beams=4,
                early_stopping=True,
            )

        text = self._processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return (text or "").strip()
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def run_trocr_on_page(engine: TrOCRHandwritingEngine, img_path: str) -> tuple[str, float]:
    """OCR one page image with TrOCR, one line band at a time.

    When the page cannot be split into bands, the whole image is recognized as
    a single line. Band files are always deleted, even if recognition raises.

    Returns:
        ``(text, pseudo_confidence)``. The text has one recognized line per
        row. The confidence in 0..1 is a length heuristic,
        ``min(1.0, len(text) / 200)``, not a model score.
    """
    bands = split_image_into_line_bands(img_path)
    # Register every band file for cleanup up front, so that an exception on one
    # band still deletes the crops that were not processed yet.
    created = [band_path for _, band_path in bands]
    lines: list[str] = []

    try:
        if not bands:
            text = engine.recognize_line_image_path(img_path)
            if text:
                lines.append(text)
        else:
            for _, band_path in sorted(bands, key=lambda band: band[0]):
                # Collapse internal whitespace runs to single spaces.
                line = " ".join(engine.recognize_line_image_path(band_path).split())
                if line:
                    lines.append(line)
    finally:
        cleanup_paths(created)

    full = "\n".join(lines).strip()
    conf = min(1.0, len(full) / 200.0) if full else 0.0
    return full, conf
|