ocrcontext 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocrcontext/__init__.py +49 -0
- ocrcontext/analyzer.py +198 -0
- ocrcontext/config.py +49 -0
- ocrcontext/engines/__init__.py +6 -0
- ocrcontext/engines/base.py +45 -0
- ocrcontext/engines/handwriting.py +103 -0
- ocrcontext/engines/paddle.py +264 -0
- ocrcontext/engines/pdf_text.py +126 -0
- ocrcontext/engines/registry.py +67 -0
- ocrcontext/engines/trocr.py +191 -0
- ocrcontext/engines/vision.py +538 -0
- ocrcontext/exceptions.py +45 -0
- ocrcontext/llm/__init__.py +10 -0
- ocrcontext/llm/drift.py +58 -0
- ocrcontext/llm/extractor.py +63 -0
- ocrcontext/llm/formatting.py +39 -0
- ocrcontext/llm/literal_preserve.py +164 -0
- ocrcontext/llm/prompts.py +157 -0
- ocrcontext/llm/refiner.py +114 -0
- ocrcontext/llm/schemas.py +99 -0
- ocrcontext/pipeline.py +162 -0
- ocrcontext/preprocessing/__init__.py +5 -0
- ocrcontext/preprocessing/image.py +177 -0
- ocrcontext/py.typed +0 -0
- ocrcontext/quality.py +76 -0
- ocrcontext/schemas.py +8 -0
- ocrcontext/types.py +55 -0
- ocrcontext/utils/__init__.py +1 -0
- ocrcontext/utils/files.py +172 -0
- ocrcontext/utils/lang.py +77 -0
- ocrcontext-0.1.0.dist-info/METADATA +207 -0
- ocrcontext-0.1.0.dist-info/RECORD +34 -0
- ocrcontext-0.1.0.dist-info/WHEEL +4 -0
- ocrcontext-0.1.0.dist-info/licenses/LICENSE +21 -0
ocrcontext/pipeline.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""Document routing + OCR orchestration (no LLM here — that's the analyzer's job).
|
|
2
|
+
|
|
3
|
+
Reproduces the retry/fallback ladder from app/api/documents/process/route.ts and
|
|
4
|
+
the page loops in OCRService / HandwritingOCRService, minus the web/Modal layer.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from .config import AnalyzerConfig
|
|
10
|
+
from .engines.base import OcrEngine, PageOcr
|
|
11
|
+
from .engines.pdf_text import extract_pdf_text_preserve_layout, has_sufficient_pdf_text
|
|
12
|
+
from .engines.registry import EngineRegistry
|
|
13
|
+
from .quality import is_ocr_text_insufficient
|
|
14
|
+
from .types import OcrResult, TextSource
|
|
15
|
+
from .utils.files import (
|
|
16
|
+
Source,
|
|
17
|
+
cleanup_paths,
|
|
18
|
+
is_pdf,
|
|
19
|
+
load_source,
|
|
20
|
+
new_temp_path,
|
|
21
|
+
rasterize_pdf,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Pipeline:
    """Route a document to the right OCR engine(s) and return raw OCR text.

    No LLM work happens here. :meth:`run` applies, in order:

    1. For PDFs (when ``config.prefer_pdf_text_layer`` is set and handwriting
       was not requested) the embedded text layer, if it has enough text.
    2. Page-by-page OCR with the printed engine, or with the handwriting
       engine when ``handwriting=True``.
    3. When printed OCR looks insufficient and
       ``config.auto_handwriting_fallback`` is set, a second pass with the
       handwriting engine whose result replaces the printed one.
    """

    def __init__(
        self,
        registry: EngineRegistry | None = None,
        config: AnalyzerConfig | None = None,
    ) -> None:
        """Create a pipeline.

        Args:
            registry: Engine registry; defaults to ``EngineRegistry.shared()``.
            config: Pipeline settings; defaults to ``AnalyzerConfig()``.
        """
        self.registry = registry or EngineRegistry.shared()
        self.config = config or AnalyzerConfig()

    def run(
        self,
        source: Source,
        *,
        lang: str | None = None,
        handwriting: bool = False,
        filename: str | None = None,
    ) -> OcrResult:
        """OCR *source* and return the raw (unrefined) result.

        Args:
            source: Path, raw bytes, or binary file-like object.
            lang: Recognition language; defaults to ``config.lang``.
            handwriting: Force the handwriting engine. This also skips the
                PDF text-layer shortcut.
            filename: Name hint used to infer the file type of bytes /
                file-like sources.

        Returns:
            An :class:`OcrResult` whose ``text_source`` identifies the path taken.
        """
        lang = lang or self.config.lang
        file_bytes, ext = load_source(source, filename=filename)

        # 1) Digital PDF text layer — exact text, no OCR / no GPU.
        if is_pdf(ext) and self.config.prefer_pdf_text_layer and not handwriting:
            full_text, page_count = extract_pdf_text_preserve_layout(file_bytes)
            if has_sufficient_pdf_text(full_text):
                return OcrResult(
                    text=full_text.strip(),
                    confidence=1.0,
                    pages=page_count,
                    text_source="pdf_text_layer",
                )

        # 2) Printed / scanned OCR (or handwriting if explicitly requested).
        result = self._ocr(file_bytes, ext, lang=lang, handwriting=handwriting)

        # 3) Auto handwriting fallback when printed OCR returns too little text.
        if (
            not handwriting
            and self.config.auto_handwriting_fallback
            and is_ocr_text_insufficient(result.text, result.pages)
        ):
            result = self._ocr(file_bytes, ext, lang=lang, handwriting=True)

        return result

    def _ocr(
        self, file_bytes: bytes, ext: str, *, lang: str, handwriting: bool
    ) -> OcrResult:
        """Run one OCR pass (printed or handwriting) over every page/image.

        Picks the engine, PDF render scale and ``min_lines`` for the mode.
        Temp images created for the pass are always removed afterwards.
        """
        engine = (
            self.registry.handwriting() if handwriting else self.registry.paddle()
        )
        render_scale = (
            self.config.pdf_render_scale_handwriting
            if handwriting
            else self.config.pdf_render_scale
        )
        min_lines = (
            self.config.min_lines_handwriting
            if handwriting
            else self.config.min_lines_per_page
        )

        image_paths, owned = self._materialize_images(file_bytes, ext, render_scale)
        try:
            return self._ocr_pages(
                engine,
                image_paths,
                lang=lang,
                handwriting=handwriting,
                min_lines=min_lines,
            )
        finally:
            cleanup_paths(owned)

    @staticmethod
    def _materialize_images(
        file_bytes: bytes, ext: str, render_scale: float
    ) -> tuple[list[str], list[str]]:
        """Write the document to on-disk image file(s) the engines can read.

        PDFs are rasterized page by page at *render_scale*. Any other
        extension is written as-is to a single temp file.

        Returns:
            ``(image_paths, owned_paths_to_cleanup)``.
        """
        if is_pdf(ext):
            paths = rasterize_pdf(file_bytes, render_scale)
            return paths, list(paths)
        # Single image: write to a temp file the engines can read.
        path = new_temp_path(ext)
        try:
            with open(path, "wb") as f:
                f.write(file_bytes)
        except BaseException:
            # The caller never receives *path* on failure, so remove any
            # partially written file here before propagating.
            cleanup_paths([path])
            raise
        return [path], [path]

    @staticmethod
    def _ocr_pages(
        engine: OcrEngine,
        image_paths: list[str],
        *,
        lang: str,
        handwriting: bool,
        min_lines: int,
    ) -> OcrResult:
        """Recognize each image with *engine* and merge the pages into one result.

        Page text layout:
        - Pages after the first are preceded by a ``--- Page N ---`` separator.
        - In handwriting mode, every non-empty page text gets a trailing newline.

        ``confidence`` is the mean of all scores reported across pages, rounded
        to 4 places (0.0 when there are none). In handwriting mode,
        ``text_source`` is derived from the per-page sources:
        - Vision only → ``vision_handwriting``.
        - Any TrOCR page → ``trocr_handwriting``.
        - Otherwise → ``handwriting_ocr``.

        In printed mode it is ``engine.text_source``.
        """
        # Collect pieces and join once: repeated ``+=`` on a str is quadratic.
        chunks: list[str] = []
        all_scores: list[float] = []
        used_vision = False
        used_trocr = False
        has_dikw = False

        for idx, img_path in enumerate(image_paths):
            if idx > 0:
                chunks.append(f"\n\n--- Page {idx + 1} ---\n\n")

            page: PageOcr = engine.recognize(
                img_path, lang=lang, min_lines=min_lines, handwriting=handwriting
            )
            chunks.append(page.text)
            if handwriting and page.text:
                chunks.append("\n")
            all_scores.extend(page.scores)
            if page.has_dikw_structure:
                has_dikw = True
            if page.text_source == "vision_handwriting":
                used_vision = True
            elif page.text_source == "trocr_handwriting":
                used_trocr = True

        full_text = "".join(chunks)
        avg_conf = sum(all_scores) / len(all_scores) if all_scores else 0.0

        text_source: TextSource = engine.text_source  # type: ignore[assignment]
        if handwriting:
            if used_vision and not used_trocr:
                text_source = "vision_handwriting"
            elif used_trocr:
                text_source = "trocr_handwriting"
            else:
                text_source = "handwriting_ocr"

        return OcrResult(
            text=full_text.strip(),
            confidence=round(avg_conf, 4),
            pages=len(image_paths),
            text_source=text_source,
            has_dikw_structure=has_dikw,
        )
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""Image preprocessing, ported verbatim from ocr-service/modal_app.py.
|
|
2
|
+
|
|
3
|
+
OpenCV / NumPy are imported lazily so the base install does not require the
|
|
4
|
+
``paddle`` extra unless an image OCR path is actually taken.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from ..exceptions import MissingDependencyError
|
|
10
|
+
from ..utils.files import new_temp_path
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _require_cv2():
    """Return the imported ``cv2`` module.

    Raises:
        MissingDependencyError: OpenCV is not installed (it ships with the
            ``paddle`` extra).
    """
    try:
        import cv2
    except ImportError as exc:  # pragma: no cover - exercised via install matrix
        raise MissingDependencyError("opencv-python-headless", "paddle") from exc
    return cv2
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _deskew_grayscale(img):
    """Correct a slight rotation of a grayscale page (e.g. a notebook photo).

    Fits the minimum-area rectangle around all "ink" pixels and derives a
    rotation angle from it. It then rotates the image about its centre with
    cubic interpolation, replicating edge pixels into the uncovered areas.

    The image is returned unchanged when:
    - fewer than 200 ink pixels are found, or
    - the estimated correction is below 0.4 degrees or above 12 degrees.

    Args:
        img: Single-channel image array. Ink is taken to be every pixel that
            is non-zero after ``bitwise_not``, i.e. anything not pure white.
            This presumably assumes dark ink on a light background — TODO
            confirm.

    Returns:
        The rotated image (same width/height), or *img* itself.
    """
    import numpy as np

    cv2 = _require_cv2()

    inv = cv2.bitwise_not(img)
    # np.where on a 2-D array yields (row_indices, col_indices), so each point
    # here is (y, x) rather than OpenCV's usual (x, y) order.
    coords = np.column_stack(np.where(inv > 0))
    # Too little ink to estimate an angle.
    if len(coords) < 200:
        return img

    # Last element of minAreaRect's ((cx, cy), (w, h), angle) result.
    # NOTE(review): OpenCV reportedly changed minAreaRect's angle range in the
    # 4.5.x series (roughly from [-90, 0) to (0, 90]). The ``< -45``
    # normalisation below matches the older convention — confirm against the
    # installed OpenCV before trusting the sign/magnitude of the correction.
    angle = cv2.minAreaRect(coords)[-1]
    if angle < -45:
        angle = -(90 + angle)
    else:
        angle = -angle

    # Skip negligible corrections and implausibly large ones (presumably
    # mis-estimates rather than genuine page skew).
    if abs(angle) < 0.4 or abs(angle) > 12:
        return img

    h, w = img.shape[:2]
    center = (w // 2, h // 2)
    matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
    # Output keeps the input size; BORDER_REPLICATE fills the uncovered corners.
    return cv2.warpAffine(
        img,
        matrix,
        (w, h),
        flags=cv2.INTER_CUBIC,
        borderMode=cv2.BORDER_REPLICATE,
    )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _suppress_notebook_ruling(img):
    """Inpaint away long horizontal ruled lines that cut through handwriting.

    A horizontal-line mask is built by:
    1. Adaptive inverse thresholding.
    2. Morphological opening with a wide 1-px-tall kernel.
    3. Dilation with the same kernel.

    The masked pixels are then filled by Telea inpainting. Images smaller than
    80 px on either side, and images whose mask has fewer than 50 pixels, are
    returned unchanged.
    """
    cv2 = _require_cv2()

    height, width = img.shape[:2]
    if min(height, width) < 80:
        return img

    inked = cv2.adaptiveThreshold(
        img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 15, 8
    )
    kernel_width = max(25, width // 25)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_width, 1))
    opened = cv2.morphologyEx(inked, cv2.MORPH_OPEN, kernel, iterations=1)
    line_mask = cv2.dilate(opened, kernel, iterations=1)

    if cv2.countNonZero(line_mask) >= 50:
        return cv2.inpaint(img, line_mask, 2, cv2.INPAINT_TELEA)
    return img
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def preprocess_image_for_ocr(img_path: str, handwriting: bool = False) -> str:
    """Contrast + denoise an image before OCR (helps faint / handwritten scans).

    Printed mode applies CLAHE (clip limit 1.8). Handwriting mode instead:
    - trims a narrow left strip from pages at least ~0.9x as wide as tall and
      wider than 400 px;
    - deskews and denoises with non-local means;
    - applies CLAHE (clip limit 2.5);
    - upscales by 1.15x.

    Args:
        img_path: Path to an image file readable by ``cv2.imread``.
        handwriting: Use the handwriting-specific pipeline.

    Returns:
        Path to a new preprocessed PNG. If the image could not be read, or the
        PNG could not be written, *img_path* is returned unchanged. The caller
        is responsible for cleaning up the new file (i.e. whenever the
        returned path differs from *img_path*).
    """
    cv2 = _require_cv2()

    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return img_path

    if handwriting:
        h, w = img.shape[:2]
        # Notebook margin rulers (e.g. 23, 22, 21...) confuse Vision - trim narrow left strip.
        if w > h * 0.9 and w > 400:
            crop_x = min(int(w * 0.08), 100)
            img = img[:, crop_x:]
        # Deskew only; skip ruled-line inpainting - it can erase strokes that touch notebook lines.
        img = _deskew_grayscale(img)
        img = cv2.fastNlMeansDenoising(img, None, h=8, templateWindowSize=7, searchWindowSize=21)
        clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))
        img = clahe.apply(img)
        # Slight upscale for thin strokes
        img = cv2.resize(img, None, fx=1.15, fy=1.15, interpolation=cv2.INTER_CUBIC)
    else:
        clahe = cv2.createCLAHE(clipLimit=1.8, tileGridSize=(8, 8))
        img = clahe.apply(img)

    out_path = new_temp_path("png")
    # cv2.imwrite can signal failure through a False return value instead of
    # raising. Without this check the caller could be handed a path to a file
    # that was never written; fall back to the untouched input instead.
    if not cv2.imwrite(out_path, img):
        return img_path
    return out_path
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def split_image_into_line_bands(img_path: str) -> list[tuple[int, str]]:
    """Cut a page image into horizontal text-line bands via a row projection.

    Steps:
    1. Adaptive-threshold a lightly blurred copy of the image.
    2. Mark a row as "text" when its ink-pixel count exceeds
       ``max(8, int(0.02 * width))``.
    3. Turn runs of at least 10 text rows into segments.
    4. Merge segments separated by a gap of 8 rows or less.
    5. Pad each merged segment by 10 px, crop it from the grayscale image,
       upscale it 2x, and write it to a temp PNG. Crops shorter than 12 px
       are skipped.

    Returns ``[(y_offset, band_image_path), ...]``, where ``y_offset`` is the
    top row of the padded crop. The result is empty when the image cannot be
    read, is smaller than 80x80, or has no text rows. The caller owns cleanup
    of the band image files. Ported from
    OCRService.process.split_image_into_line_bands.
    """
    import numpy as np

    cv2 = _require_cv2()

    gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        return []
    height, width = gray.shape[:2]
    if height < 80 or width < 80:
        return []

    smoothed = cv2.GaussianBlur(gray, (3, 3), 0)
    ink = cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 31, 12
    )

    # Horizontal projection profile: number of ink pixels in each row.
    ink_per_row = np.sum(ink > 0, axis=1)
    min_ink = max(8, int(0.02 * width))
    is_text_row = ink_per_row > min_ink

    # Collect runs of consecutive text rows that are at least 10 rows tall.
    total_rows = len(is_text_row)
    segments: list[tuple[int, int]] = []
    run_start: int | None = None
    for row, flag in enumerate(is_text_row):
        if flag:
            if run_start is None:
                run_start = row
            continue
        if run_start is not None:
            if row - run_start >= 10:
                segments.append((run_start, row))
            run_start = None
    if run_start is not None and total_rows - run_start >= 10:
        segments.append((run_start, total_rows))

    if not segments:
        return []

    # Join segments split by a small gap (same line broken by weak strokes).
    merged: list[list[int]] = [list(segments[0])]
    for seg_start, seg_end in segments[1:]:
        if seg_start - merged[-1][1] <= 8:
            merged[-1][1] = seg_end
        else:
            merged.append([seg_start, seg_end])

    pad = 10
    bands: list[tuple[int, str]] = []
    for seg_start, seg_end in merged:
        top = max(0, seg_start - pad)
        bottom = min(height, seg_end + pad)
        crop = gray[top:bottom, :]
        if crop.shape[0] < 12:
            continue
        enlarged = cv2.resize(crop, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
        band_path = new_temp_path("png")
        cv2.imwrite(band_path, enlarged)
        bands.append((top, band_path))

    return bands
|
ocrcontext/py.typed
ADDED
|
File without changes
|
ocrcontext/quality.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""OCR text-quality heuristics, ported from lib/ocr/ocr-quality.ts and
|
|
2
|
+
lib/ocr/handwriting-refine.ts / detect-dikw.ts.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import re
|
|
8
|
+
|
|
9
|
+
from .types import RefinementMode
|
|
10
|
+
|
|
11
|
+
_ALNUM_RE = re.compile(r"[^\W_]", re.UNICODE)  # one Unicode letter or digit
_WS_RE = re.compile(r"\s+")


def is_ocr_text_insufficient(text: str, page_count: int = 1) -> bool:
    """Return True when OCR produced too little usable text (common on handwriting).

    After collapsing whitespace, the text is insufficient when any of these holds:
    - it is empty;
    - it is shorter than ``max(50, 25 * pages)`` characters (``pages`` is at
      least 1);
    - fewer than 20% of its characters are Unicode letters/digits.
    """
    normalized = _WS_RE.sub(" ", text).strip()
    if not normalized:
        return True

    required_chars = max(50, 25 * max(1, page_count))
    if len(normalized) < required_chars:
        return True

    letters_and_digits = sum(1 for _ in _ALNUM_RE.finditer(normalized))
    return letters_and_digits / max(len(normalized), 1) < 0.2
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# --- DIKW detection (ported from lib/ocr/detect-dikw.ts) -----------------------

_DIKW_LETTERS = {"W", "K", "I", "D"}


def _is_dikw_letter_token(token: str) -> bool:
    """True when *token*'s ASCII letters are exactly one of D/I/K/W (any case).

    Non-letters (punctuation, digits, non-ASCII) are ignored, so ``"k."`` and
    ``"(I)"`` qualify, while ``"WK"`` does not.
    """
    ascii_letters = "".join(
        ch for ch in token if ("a" <= ch <= "z") or ("A" <= ch <= "Z")
    )
    return len(ascii_letters) == 1 and ascii_letters.upper() in _DIKW_LETTERS
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _is_pyramid_header(line: str) -> bool:
    """True when *line* mentions "piramid", "pyramid" or "dikw" (case-insensitive)."""
    lowered = line.lower()
    return any(marker in lowered for marker in ("piramid", "pyramid", "dikw"))
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _row_looks_like_dikw_pair(line: str) -> bool:
    """True when *line* reads like one DIKW row, e.g. ``"W  Wisdom"``.

    The line is split on whitespace and ``·``. It needs at least two tokens,
    at least one single-letter D/I/K/W token, and at least one token with
    more than two word characters.
    """
    parts = list(filter(None, re.split(r"[\s·]+", line)))
    if len(parts) < 2:
        return False
    if not any(_is_dikw_letter_token(part) for part in parts):
        return False
    return any(len(re.sub(r"\W", "", part)) > 2 for part in parts)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def detect_dikw_structure(text: str) -> bool:
    """True when OCR text looks like a DIKW / pyramid diagram (not plain prose).

    Looks at non-blank stripped lines. Returns True as soon as a line mentions
    a pyramid/DIKW header, or once two lines look like "letter + word" DIKW
    rows.
    """
    pair_rows = 0
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        if _is_pyramid_header(line):
            return True
        if _row_looks_like_dikw_pair(line):
            pair_rows += 1
            if pair_rows >= 2:
                return True
    return False
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def handwriting_refinement_mode(
    raw_text: str, has_dikw_from_ocr: bool | None = None
) -> RefinementMode:
    """Choose the LLM refinement mode for handwritten text.

    Returns ``HANDWRITING_LAYOUT`` in either case:
    - the OCR step flagged a DIKW structure (``has_dikw_from_ocr is True``;
      other truthy values do not count);
    - :func:`detect_dikw_structure` recognises one in *raw_text*.

    Otherwise returns ``HANDWRITING_PROSE``.
    """
    if has_dikw_from_ocr is True:
        return RefinementMode.HANDWRITING_LAYOUT
    if detect_dikw_structure(raw_text):
        return RefinementMode.HANDWRITING_LAYOUT
    return RefinementMode.HANDWRITING_PROSE
|
ocrcontext/schemas.py
ADDED
ocrcontext/types.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""Public result and value types."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Literal, Optional
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RefinementMode(str, Enum):
    """How the LLM post-processor should treat raw OCR text.

    The values are the original pipeline's mode strings:

    - ``layout`` — digital PDFs: reconstruct a clean structure.
    - ``conservative`` — printed OCR of images/scans: minimal character-level fixes.
    - ``handwriting_layout`` — handwritten notes, lists, tables, diagrams.
    - ``handwriting_prose`` — handwritten poems, paragraphs, letters.

    Because the enum also subclasses ``str``, members compare equal to their
    string values (``RefinementMode.LAYOUT == "layout"``).
    """

    LAYOUT = "layout"
    CONSERVATIVE = "conservative"
    HANDWRITING_LAYOUT = "handwriting_layout"
    HANDWRITING_PROSE = "handwriting_prose"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Which engine produced a piece of text. These strings are the original
# pipeline's ``text_source`` contract; downstream behaviour (e.g. skipping LLM
# refinement) relies on them, so keep them stable.
TextSource = Literal[
    "pdf_text_layer", "ocr", "vision_handwriting", "trocr_handwriting", "handwriting_ocr"
]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class OcrResult(BaseModel):
    """The output of an OCR / analysis run.

    ``str(result)`` returns :attr:`text`, so a result can be printed directly.

    NOTE(review): this module enables ``from __future__ import annotations``,
    so pydantic resolves these annotations from strings. Keep ``Optional[str]``
    rather than ``str | None`` unless Python >= 3.10 is guaranteed — confirm
    before changing.
    """

    # Extracted text. The OCR pipeline sets it from the text layer or OCR
    # output with surrounding whitespace stripped.
    text: str = Field(description="Extracted (and optionally refined) plain text.")
    # Mean of the engine scores. The pipeline reports 1.0 for PDF text-layer
    # extraction and 0.0 when an engine returned no scores.
    confidence: float = Field(default=0.0, description="Mean recognition confidence (0..1).")
    # Page count from the PDF text layer, or the number of images OCR'd.
    pages: int = Field(default=1, description="Number of pages processed.")
    # One of the TextSource literals; downstream logic relies on this value.
    text_source: TextSource = Field(
        default="ocr", description="Which engine produced the text."
    )
    # True when any OCR'd page reported a DIKW/pyramid layout.
    has_dikw_structure: bool = Field(
        default=False, description="True when handwriting looks like a DIKW/pyramid diagram."
    )
    # Whether an LLM post-processing step rewrote ``text``.
    refined: bool = Field(default=False, description="True if an LLM refined the raw OCR text.")
    # Pre-refinement OCR text; presumably only populated when ``refined`` is True.
    raw_text: Optional[str] = Field(
        default=None, description="Original OCR text before refinement (when refined)."
    )

    def __str__(self) -> str:  # convenient: print(result) -> the text
        return self.text
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Internal utilities (language mapping, file handling)."""
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""File / source loading and PDF rasterization helpers.
|
|
2
|
+
|
|
3
|
+
Replaces the Modal service's ``/tmp`` plumbing with cross-platform temp handling
|
|
4
|
+
(uses the OS temp dir via ``tempfile`` so it works on Windows too).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
import tempfile
|
|
12
|
+
import uuid
|
|
13
|
+
from contextlib import contextmanager
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import IO, Iterator, Union
|
|
16
|
+
|
|
17
|
+
from ..exceptions import UnsupportedFileError
|
|
18
|
+
|
|
19
|
+
# Everything the public API accepts as a document: a filesystem path (str or
# Path), raw bytes, or a binary file-like object.
Source = Union[str, Path, bytes, bytearray, IO[bytes]]

# Raster image extensions (lowercase, no leading dot) recognised by is_image().
IMAGE_EXTS = {"bmp", "gif", "jpeg", "jpg", "png", "tif", "tiff", "webp"}
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def is_ascii(text: str) -> bool:
    """True when every character of *text* is in the 7-bit ASCII range (also for "")."""
    return not any(ord(ch) > 127 for ch in text)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def short_path(path: str) -> str:
    """Return the Windows 8.3 short path (ASCII) for an existing path.

    PaddlePaddle/OpenCV's C++ file readers fail on paths containing non-ASCII
    characters (a common case: a non-ASCII Windows username). The 8.3 short
    path aliases the same file with ASCII-only characters.

    This is best-effort. *path* is returned unchanged:
    - off Windows;
    - when ``GetShortPathNameW`` fails (e.g. the path does not exist);
    - on any ctypes error.

    On volumes with 8.3 names disabled the result may still be non-ASCII.
    """
    if sys.platform != "win32":
        return path
    try:
        import ctypes
        from ctypes import wintypes

        _GetShortPathNameW = ctypes.windll.kernel32.GetShortPathNameW
        _GetShortPathNameW.argtypes = [wintypes.LPCWSTR, wintypes.LPWSTR, wintypes.DWORD]
        _GetShortPathNameW.restype = wintypes.DWORD

        size = 560
        # At most two attempts: the initial buffer, then one sized from the
        # required length the API reports when the first buffer is too small.
        for _ in range(2):
            buf = ctypes.create_unicode_buffer(size)
            result = _GetShortPathNameW(str(path), buf, size)
            if result == 0:
                break  # failure (e.g. path does not exist)
            if result < size:
                # Success: *result* is the copied length, excluding the NUL.
                return buf.value
            # Buffer too small: *result* is the required size including the
            # NUL, and *buf* holds no valid path — retry with a larger buffer.
            size = result
    except Exception:
        # Deliberately best-effort: any ctypes/OS failure keeps the long path.
        pass
    return path
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def ascii_safe_dir(path: str) -> str:
    """Return an ASCII-only spelling of the existing directory *path* when possible.

    ASCII paths are returned as-is. Otherwise the Windows 8.3 short form is
    used if it is ASCII; failing that, *path* is returned unchanged.
    """
    if not is_ascii(path):
        shortened = short_path(path)
        if is_ascii(shortened):
            return shortened
    return path
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _temp_dir() -> str:
    """Base directory for temp files, made ASCII-safe via :func:`ascii_safe_dir`.

    Uses ``$OCRCONTEXT_TMPDIR`` when it is set and non-empty; otherwise
    ``tempfile.gettempdir()``.
    """
    override = os.environ.get("OCRCONTEXT_TMPDIR")
    return ascii_safe_dir(override if override else tempfile.gettempdir())
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def new_temp_path(ext: str) -> str:
    """Build a unique ``ocrctx_<hex>.<ext>`` path inside the temp directory.

    Leading dots are stripped from *ext*. Only the name is generated — the
    file itself is not created.
    """
    suffix = ext.lstrip(".")
    filename = "ocrctx_" + uuid.uuid4().hex + "." + suffix
    return os.path.join(_temp_dir(), filename)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def load_source(source: Source, *, filename: str | None = None) -> tuple[bytes, str]:
|
|
77
|
+
"""Normalize any accepted source into ``(file_bytes, extension)``.
|
|
78
|
+
|
|
79
|
+
``extension`` is a lowercase string without a leading dot (e.g. ``"pdf"``).
|
|
80
|
+
"""
|
|
81
|
+
# Path / path-like string
|
|
82
|
+
if isinstance(source, (str, Path)):
|
|
83
|
+
path = Path(source)
|
|
84
|
+
if not path.exists():
|
|
85
|
+
raise UnsupportedFileError(f"File not found: {path}")
|
|
86
|
+
ext = path.suffix.lower().lstrip(".")
|
|
87
|
+
if not ext:
|
|
88
|
+
raise UnsupportedFileError(f"Cannot determine file type from path: {path}")
|
|
89
|
+
return path.read_bytes(), ext
|
|
90
|
+
|
|
91
|
+
# Raw bytes — need a filename hint for the extension
|
|
92
|
+
if isinstance(source, (bytes, bytearray)):
|
|
93
|
+
ext = _ext_from_filename(filename)
|
|
94
|
+
return bytes(source), ext
|
|
95
|
+
|
|
96
|
+
# File-like object
|
|
97
|
+
if hasattr(source, "read"):
|
|
98
|
+
data = source.read()
|
|
99
|
+
if not isinstance(data, (bytes, bytearray)):
|
|
100
|
+
raise UnsupportedFileError("File-like source must be opened in binary mode.")
|
|
101
|
+
name = filename or getattr(source, "name", None)
|
|
102
|
+
ext = _ext_from_filename(name)
|
|
103
|
+
return bytes(data), ext
|
|
104
|
+
|
|
105
|
+
raise UnsupportedFileError(f"Unsupported source type: {type(source)!r}")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _ext_from_filename(filename: str | None) -> str:
|
|
109
|
+
if not filename or "." not in filename:
|
|
110
|
+
raise UnsupportedFileError(
|
|
111
|
+
"Could not infer file extension. Pass `filename=` when supplying raw bytes."
|
|
112
|
+
)
|
|
113
|
+
return filename.rsplit(".", 1)[-1].lower()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def is_pdf(ext: str) -> bool:
    """True when the dot-less extension *ext* is ``pdf`` in any letter case."""
    normalized = ext.lower()
    return normalized == "pdf"
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def is_image(ext: str) -> bool:
    """True when the dot-less extension *ext* (any case) is in :data:`IMAGE_EXTS`."""
    normalized = ext.lower()
    return normalized in IMAGE_EXTS
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@contextmanager
def temp_file(file_bytes: bytes, ext: str) -> Iterator[str]:
    """Write bytes to a temp file, yield its path, and clean up afterwards.

    The file is removed when the ``with`` block exits, including when the
    write itself fails part-way.
    """
    path = new_temp_path(ext)
    try:
        # Write inside the try so a failed/partial write is still cleaned up.
        with open(path, "wb") as f:
            f.write(file_bytes)
        yield path
    finally:
        _safe_remove(path)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def rasterize_pdf(file_bytes: bytes, scale: float, prefix: str = "page") -> list[str]:
    """Render every PDF page to a PNG on disk; return the image paths in page order.

    Ported from the PDF render loop in OCRService.process /
    HandwritingOCRService.process.

    Args:
        file_bytes: Raw PDF bytes.
        scale: Zoom factor for both axes, passed as ``fitz.Matrix(scale, scale)``.
        prefix: Currently unused; kept for signature compatibility.

    Returns:
        Paths of the written PNG files. The caller owns cleanup (use
        :func:`cleanup_paths`). If rendering fails part-way, PNGs already
        written by this call are removed before the exception propagates,
        since the caller never receives the list in that case.
    """
    import fitz  # PyMuPDF — core dependency

    from PIL import Image

    image_paths: list[str] = []
    mat = fitz.Matrix(scale, scale)  # same transform for every page
    pdf_document = fitz.open(stream=file_bytes, filetype="pdf")
    try:
        for page_num in range(len(pdf_document)):
            page = pdf_document[page_num]
            pix = page.get_pixmap(matrix=mat)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            out_path = new_temp_path("png")
            # Record before saving so a partially written file is cleaned up too.
            image_paths.append(out_path)
            img.save(out_path)
    except BaseException:
        cleanup_paths(image_paths)
        raise
    finally:
        pdf_document.close()
    return image_paths
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def cleanup_paths(paths: list[str]) -> None:
    """Best-effort delete of every path in *paths*.

    Empty entries, missing files and ``OSError`` are silently skipped.
    """
    for leftover in paths:
        _safe_remove(leftover)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _safe_remove(path: str) -> None:
    """Delete *path* if it exists.

    Empty or ``None`` paths are ignored, and any ``OSError`` is swallowed.
    """
    if not path:
        return
    try:
        if os.path.exists(path):
            os.remove(path)
    except OSError:
        pass
|