mathcraft-ocr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mathcraft_ocr/__init__.py +39 -0
- mathcraft_ocr/__main__.py +6 -0
- mathcraft_ocr/adapters/__init__.py +13 -0
- mathcraft_ocr/adapters/common.py +46 -0
- mathcraft_ocr/adapters/formula_detector.py +131 -0
- mathcraft_ocr/adapters/formula_recognizer.py +151 -0
- mathcraft_ocr/adapters/text_detector.py +57 -0
- mathcraft_ocr/adapters/text_recognizer.py +121 -0
- mathcraft_ocr/api.py +14 -0
- mathcraft_ocr/cache.py +135 -0
- mathcraft_ocr/cli.py +110 -0
- mathcraft_ocr/debug_blocks.py +202 -0
- mathcraft_ocr/doctor.py +50 -0
- mathcraft_ocr/downloader.py +97 -0
- mathcraft_ocr/errors.py +21 -0
- mathcraft_ocr/hardware.py +203 -0
- mathcraft_ocr/image.py +33 -0
- mathcraft_ocr/layout.py +892 -0
- mathcraft_ocr/manifest.py +89 -0
- mathcraft_ocr/manifests/models.v1.json +89 -0
- mathcraft_ocr/providers.py +80 -0
- mathcraft_ocr/results.py +53 -0
- mathcraft_ocr/runtime.py +535 -0
- mathcraft_ocr/serialization.py +120 -0
- mathcraft_ocr/worker.py +131 -0
- mathcraft_ocr-0.1.0.dist-info/METADATA +184 -0
- mathcraft_ocr-0.1.0.dist-info/RECORD +31 -0
- mathcraft_ocr-0.1.0.dist-info/WHEEL +5 -0
- mathcraft_ocr-0.1.0.dist-info/entry_points.txt +3 -0
- mathcraft_ocr-0.1.0.dist-info/licenses/LICENSE +21 -0
- mathcraft_ocr-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
__version__ = "0.1.0"
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"DoctorReport",
|
|
9
|
+
"FormulaRecognitionResult",
|
|
10
|
+
"MathCraftBlock",
|
|
11
|
+
"MathCraftError",
|
|
12
|
+
"MathCraftRuntime",
|
|
13
|
+
"MixedRecognitionResult",
|
|
14
|
+
"OCRRegion",
|
|
15
|
+
"__version__",
|
|
16
|
+
"run_doctor",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def __getattr__(name: str) -> object:
    """Lazily resolve the package's public names (PEP 562 module ``__getattr__``).

    Submodules are imported only when one of their exported names is first
    accessed. The resolved object is then stored in the module globals so
    later lookups no longer go through this hook.

    Args:
        name: attribute requested on the package.

    Returns:
        The object exported under ``name``.

    Raises:
        AttributeError: if ``name`` is not one of the lazily exported names.
    """
    if name in {
        "FormulaRecognitionResult",
        "MathCraftBlock",
        "MathCraftRuntime",
        "MixedRecognitionResult",
        "OCRRegion",
    }:
        from . import api

        value = getattr(api, name)
    elif name in {"DoctorReport", "run_doctor"}:
        from . import doctor

        value = getattr(doctor, name)
    elif name == "MathCraftError":
        from .errors import MathCraftError

        value = MathCraftError
    else:
        # Match the interpreter's standard wording so error messages and
        # tooling that inspects them behave as for any other module.
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # Cache on the module so subsequent accesses bypass this hook.
    globals()[name] = value
    return value
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
from .formula_detector import warmup_formula_detector
|
|
4
|
+
from .formula_recognizer import warmup_formula_recognizer
|
|
5
|
+
from .text_detector import warmup_text_detector
|
|
6
|
+
from .text_recognizer import warmup_pp_text_recognizer
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"warmup_formula_detector",
|
|
10
|
+
"warmup_formula_recognizer",
|
|
11
|
+
"warmup_text_detector",
|
|
12
|
+
"warmup_pp_text_recognizer",
|
|
13
|
+
]
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import importlib
|
|
6
|
+
from functools import lru_cache
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
from ..providers import GPU_PROVIDER_NAMES, ProviderInfo
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _ort():
    """Import and return the ``onnxruntime`` module on demand.

    The import is deferred to call time rather than done at module import.
    """
    module_name = "onnxruntime"
    return importlib.import_module(module_name)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def session_providers(provider_info: ProviderInfo) -> list[str]:
    """Choose the execution-provider list for a new ONNX Runtime session.

    - If CPU is not available, every available provider is returned unchanged.
    - If the active provider is a GPU provider, it is listed first with CPU as
      the fallback.
    - Otherwise only the CPU provider is used.
    """
    available = list(provider_info.available_providers)
    cpu_name = "CPUExecutionProvider"
    active = provider_info.active_provider
    has_cpu = cpu_name in available
    wants_gpu = bool(active) and active in GPU_PROVIDER_NAMES
    if not has_cpu:
        return available
    if wants_gpu:
        return [active, cpu_name]
    return [cpu_name]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def create_session(model_path: str | Path, provider_info: ProviderInfo):
    """Return a (cached) ONNX Runtime session for ``model_path``.

    The path is resolved to an absolute string so equivalent paths share one
    cache entry. If a GPU provider was requested but the session did not
    actually enable it, a ``RuntimeError`` is raised rather than silently
    running on another provider.
    """
    resolved_path = str(Path(model_path).resolve())
    wanted = tuple(session_providers(provider_info))
    session = _create_session_cached(resolved_path, wanted)
    actual = list(session.get_providers() or [])
    active = provider_info.active_provider
    gpu_requested = bool(active) and active in GPU_PROVIDER_NAMES
    if gpu_requested and active not in actual:
        raise RuntimeError(
            f"requested ONNX GPU provider {active}, but session providers are {actual}"
        )
    return session
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@lru_cache(maxsize=16)
def _create_session_cached(model_path: str, providers: tuple[str, ...]):
    """Build an ``onnxruntime.InferenceSession``, memoised on its arguments.

    ``providers`` is taken as a tuple so it is hashable for ``lru_cache``; it
    is converted back to a list for the ONNX Runtime constructor.
    """
    return _ort().InferenceSession(model_path, providers=list(providers))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def clear_session_cache() -> None:
    """Forget every memoised ONNX Runtime session so later calls rebuild them."""
    cached_factory = _create_session_cached
    cached_factory.cache_clear()
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from .common import create_session
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True)
class FormulaBox:
    """An immutable record for a single detected formula region."""

    # Four (x, y) corner points. detect_formula_boxes in this module fills them
    # as top-left, top-right, bottom-right, bottom-left in source-image pixels.
    box: tuple[tuple[float, float], tuple[float, float], tuple[float, float], tuple[float, float]]
    # Confidence of the winning class for this region.
    score: float
    # Class name ("embedding" / "isolated"), or the stringified class id otherwise.
    label: str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def warmup_formula_detector(model_dir: str | Path, provider_info) -> None:
|
|
27
|
+
root = Path(model_dir)
|
|
28
|
+
candidates = sorted(root.glob("*mfd*.onnx"))
|
|
29
|
+
if not candidates:
|
|
30
|
+
raise FileNotFoundError(f"no mfd onnx file found under {root}")
|
|
31
|
+
create_session(candidates[0], provider_info)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _letterbox(
    image: np.ndarray,
    target_size: int = 768,
    *,
    pad_value: int = 114,
) -> tuple[np.ndarray, float, tuple[float, float]]:
    """Fit ``image`` into a ``target_size`` square canvas, keeping aspect ratio.

    The image is resized by a single ``scale`` factor and pasted roughly
    centred on a 3-channel uint8 canvas filled with ``pad_value``.

    Args:
        image: source image; assumed H x W x 3 uint8 (the canvas is 3-channel).
        target_size: side length of the square output canvas.
        pad_value: fill value for the padding area (defaults to the original 114).

    Returns:
        ``(canvas, scale, (left, top))``. A point ``p`` in the source maps to
        ``p * scale + (left, top)`` on the canvas.
    """
    height, width = image.shape[:2]
    scale = min(target_size / width, target_size / height)
    # Clamp to at least one pixel. An extremely thin image (e.g. 1 x 4000)
    # would otherwise round one side to 0, and cv2.resize rejects an empty dsize.
    new_w = max(1, int(round(width * scale)))
    new_h = max(1, int(round(height * scale)))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    canvas = np.full((target_size, target_size, 3), pad_value, dtype=np.uint8)
    pad_x = (target_size - new_w) / 2
    pad_y = (target_size - new_h) / 2
    # The -0.1 bias sends the odd leftover pixel of padding to the right/bottom.
    left = int(round(pad_x - 0.1))
    top = int(round(pad_y - 0.1))
    canvas[top : top + new_h, left : left + new_w] = resized
    return canvas, scale, (float(left), float(top))
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _nms_xyxy(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float) -> list[int]:
    """Greedy non-maximum suppression on corner-format boxes.

    Args:
        boxes: (N, 4) array of ``x1, y1, x2, y2`` corners.
        scores: N confidence values aligned with ``boxes``.
        iou_threshold: a box is dropped when its IoU with an already-selected
            box is strictly greater than this value.

    Returns:
        Indices of the surviving boxes, in descending score order.
    """
    if len(boxes) == 0:
        return []
    left, top, right, bottom = (boxes[:, col] for col in range(4))
    box_area = np.maximum(0.0, right - left) * np.maximum(0.0, bottom - top)
    remaining = scores.argsort()[::-1]
    selected: list[int] = []
    while remaining.size > 0:
        best = int(remaining[0])
        selected.append(best)
        others = remaining[1:]
        if others.size == 0:
            break
        overlap_w = np.maximum(
            0.0, np.minimum(right[best], right[others]) - np.maximum(left[best], left[others])
        )
        overlap_h = np.maximum(
            0.0, np.minimum(bottom[best], bottom[others]) - np.maximum(top[best], top[others])
        )
        overlap = overlap_w * overlap_h
        union = box_area[best] + box_area[others] - overlap
        # Degenerate (zero-area) unions count as IoU 0 instead of dividing by 0.
        iou = np.divide(overlap, union, out=np.zeros_like(overlap), where=union > 0)
        remaining = others[iou <= iou_threshold]
    return selected
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def detect_formula_boxes(
    image_rgb: np.ndarray,
    model_dir: str | Path,
    provider_info,
    *,
    confidence_threshold: float = 0.25,
    iou_threshold: float = 0.45,
    input_size: int = 768,
) -> tuple[FormulaBox, ...]:
    """Run the formula-detection ONNX model on an RGB image.

    Steps: letterbox the image to ``input_size``, run the first
    ``*mfd*.onnx`` model in ``model_dir``, keep candidates whose best class
    score reaches ``confidence_threshold``, map boxes back to source-image
    pixels (clipped to the image), then apply NMS at ``iou_threshold``.

    NOTE(review): the raw output is read as (batch, 4 + num_classes,
    candidates), with rows of centre-x, centre-y, width, height followed by
    class scores. This looks like a YOLO-style head; confirm against the model.

    Returns:
        Surviving detections as ``FormulaBox`` records, in NMS order.

    Raises:
        FileNotFoundError: if no ``*mfd*.onnx`` file exists in ``model_dir``.
    """
    root = Path(model_dir)
    onnx_files = sorted(root.glob("*mfd*.onnx"))
    if len(onnx_files) == 0:
        raise FileNotFoundError(f"no mfd onnx file found under {root}")
    session = create_session(onnx_files[0], provider_info)

    canvas, scale, offset = _letterbox(image_rgb, input_size)
    pad_x, pad_y = offset
    # HWC uint8 -> 1 x C x H x W float32 scaled to [0, 1].
    chw = canvas.astype(np.float32).transpose(2, 0, 1)
    batch = chw[np.newaxis, ...] / 255.0
    input_name = session.get_inputs()[0].name
    raw = session.run(None, {input_name: batch})[0]

    rows = np.asarray(raw[0]).T
    if rows.size == 0 or rows.shape[1] < 6:
        return ()
    per_class = rows[:, 4:]
    best_class = np.argmax(per_class, axis=1)
    best_score = per_class[np.arange(len(per_class)), best_class]
    confident = best_score >= confidence_threshold
    if not np.any(confident):
        return ()
    centers_sizes = rows[:, :4][confident]
    best_class = best_class[confident]
    best_score = best_score[confident]

    cx, cy, bw, bh = (centers_sizes[:, i] for i in range(4))
    half_w = bw / 2
    half_h = bh / 2
    corners = np.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], axis=1)
    # Undo the letterbox: remove the padding offset, then the resize factor.
    corners[:, [0, 2]] = (corners[:, [0, 2]] - pad_x) / scale
    corners[:, [1, 3]] = (corners[:, [1, 3]] - pad_y) / scale
    img_h, img_w = image_rgb.shape[:2]
    corners[:, [0, 2]] = np.clip(corners[:, [0, 2]], 0, img_w)
    corners[:, [1, 3]] = np.clip(corners[:, [1, 3]], 0, img_h)

    class_names = ("embedding", "isolated")
    detections: list[FormulaBox] = []
    for idx in _nms_xyxy(corners, best_score, iou_threshold):
        left, top, right, bottom = corners[idx].tolist()
        class_id = int(best_class[idx])
        if class_id < len(class_names):
            label_text = class_names[class_id]
        else:
            label_text = str(class_id)
        detections.append(
            FormulaBox(
                box=((left, top), (right, top), (right, bottom), (left, bottom)),
                score=float(best_score[idx]),
                label=label_text,
            )
        )
    return tuple(detections)
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from functools import lru_cache
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from PIL import Image
|
|
11
|
+
|
|
12
|
+
from .common import create_session
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _softmax(logits: np.ndarray) -> np.ndarray:
    """Softmax along the last axis, shifted by the row maximum for stability."""
    peak = np.max(logits, axis=-1, keepdims=True)
    unnormalized = np.exp(logits - peak)
    total = np.sum(unnormalized, axis=-1, keepdims=True)
    return unnormalized / total
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@lru_cache(maxsize=8)
def _load_processor(model_dir: str):
    """Load the TrOCR processor (image processor + tokenizer) for ``model_dir``.

    ``transformers`` is imported only when this function is first called.
    Results are memoised per directory string.
    """
    from transformers import AutoTokenizer, TrOCRProcessor, ViTImageProcessor

    return TrOCRProcessor(
        image_processor=ViTImageProcessor.from_pretrained(model_dir, use_fast=False),
        tokenizer=AutoTokenizer.from_pretrained(model_dir, use_fast=True),
    )
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def warmup_formula_recognizer(model_dir: str | Path, provider_info) -> None:
|
|
31
|
+
root = Path(model_dir)
|
|
32
|
+
encoder = root / "encoder_model.onnx"
|
|
33
|
+
decoder = root / "decoder_model.onnx"
|
|
34
|
+
if not encoder.is_file():
|
|
35
|
+
raise FileNotFoundError(f"missing encoder model under {root}")
|
|
36
|
+
if not decoder.is_file():
|
|
37
|
+
raise FileNotFoundError(f"missing decoder model under {root}")
|
|
38
|
+
create_session(encoder, provider_info)
|
|
39
|
+
create_session(decoder, provider_info)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _load_generation_ids(model_dir: Path, tokenizer) -> tuple[int, int | None]:
|
|
43
|
+
decoder_start_id = None
|
|
44
|
+
eos_id = None
|
|
45
|
+
for filename in ("generation_config.json", "config.json"):
|
|
46
|
+
path = model_dir / filename
|
|
47
|
+
if not path.is_file():
|
|
48
|
+
continue
|
|
49
|
+
try:
|
|
50
|
+
data = json.loads(path.read_text(encoding="utf-8-sig"))
|
|
51
|
+
except Exception:
|
|
52
|
+
continue
|
|
53
|
+
decoder_start_id = data.get("decoder_start_token_id", decoder_start_id)
|
|
54
|
+
eos_id = data.get("eos_token_id", eos_id)
|
|
55
|
+
decoder_config = data.get("decoder")
|
|
56
|
+
if isinstance(decoder_config, dict):
|
|
57
|
+
decoder_start_id = decoder_config.get("decoder_start_token_id", decoder_start_id)
|
|
58
|
+
eos_id = decoder_config.get("eos_token_id", eos_id)
|
|
59
|
+
if decoder_start_id is not None and eos_id is not None:
|
|
60
|
+
break
|
|
61
|
+
if decoder_start_id is None:
|
|
62
|
+
decoder_start_id = tokenizer.bos_token_id
|
|
63
|
+
if decoder_start_id is None:
|
|
64
|
+
raise ValueError(f"missing decoder_start_token_id under {model_dir}")
|
|
65
|
+
if eos_id is None:
|
|
66
|
+
eos_id = tokenizer.eos_token_id
|
|
67
|
+
return int(decoder_start_id), int(eos_id) if eos_id is not None else None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def recognize_formula_image(
    image: Image.Image | np.ndarray,
    model_dir: str | Path,
    provider_info,
    *,
    max_new_tokens: int = 256,
) -> tuple[str, float]:
    """Recognise one formula crop.

    Wraps the image in a single-element batch for ``recognize_formula_images``
    and unwraps the result.

    Returns:
        ``(decoded_text, mean_token_probability)`` for the image.
    """
    batch_results = recognize_formula_images(
        [image], model_dir, provider_info, max_new_tokens=max_new_tokens
    )
    return batch_results[0]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _decoder_input_names(decoder_session) -> tuple[str, str]:
    """Return the decoder's ``(token_ids, encoder_hidden_states)`` input names.

    The conventional names ``input_ids`` / ``encoder_hidden_states`` are used
    when the model declares them. Otherwise token ids fall back to the first
    input and hidden states to the next distinct input (the old positional
    binding).
    """
    names = [node.name for node in decoder_session.get_inputs()]
    ids_name = "input_ids" if "input_ids" in names else names[0]
    if "encoder_hidden_states" in names:
        hidden_name = "encoder_hidden_states"
    else:
        hidden_name = [name for name in names if name != ids_name][0]
    return ids_name, hidden_name


def recognize_formula_images(
    images: list[Image.Image | np.ndarray],
    model_dir: str | Path,
    provider_info,
    *,
    max_new_tokens: int = 256,
) -> list[tuple[str, float]]:
    """Recognise a batch of formula crops with the encoder/decoder ONNX pair.

    Decoding is greedy. Each step re-runs the decoder over the full token
    prefix and appends the arg-max token. A row stops once it emits the EOS
    id, and decoding ends when every row has stopped or after
    ``max_new_tokens`` steps.

    Args:
        images: PIL images or arrays (arrays go through ``Image.fromarray``).
        model_dir: directory with ``encoder_model.onnx``, ``decoder_model.onnx``
            and the processor/tokenizer files.
        provider_info: ONNX Runtime provider selection passed to ``create_session``.
        max_new_tokens: upper bound on decoding steps.

    Returns:
        One ``(text, score)`` per input image, in input order. ``score`` is the
        mean probability of the emitted tokens, or 0.0 if none were emitted.
    """
    if not images:
        return []
    root = Path(model_dir)
    processor = _load_processor(str(root))
    encoder_session = create_session(root / "encoder_model.onnx", provider_info)
    decoder_session = create_session(root / "decoder_model.onnx", provider_info)

    pil_images = [image if isinstance(image, Image.Image) else Image.fromarray(image) for image in images]
    features = processor(images=pil_images, return_tensors="np")
    pixel_values = np.asarray(features["pixel_values"], dtype=np.float32)

    encoder_input_name = encoder_session.get_inputs()[0].name
    encoder_hidden_states = encoder_session.run(
        None,
        {encoder_input_name: pixel_values},
    )[0]

    # Resolve decoder input names once instead of on every decoding step.
    ids_input_name, hidden_input_name = _decoder_input_names(decoder_session)

    tokenizer = processor.tokenizer
    decoder_start_id, eos_id = _load_generation_ids(root, tokenizer)
    batch_size = len(pil_images)
    input_ids = np.full((batch_size, 1), decoder_start_id, dtype=np.int64)
    token_ids: list[list[int]] = [[] for _ in range(batch_size)]
    token_scores: list[list[float]] = [[] for _ in range(batch_size)]
    finished = np.zeros((batch_size,), dtype=bool)
    # Rows that have stopped still need a token each step so the batch
    # stays rectangular; they are padded with this id.
    pad_after_finish_id = eos_id if eos_id is not None else decoder_start_id

    for _ in range(max_new_tokens):
        logits = decoder_session.run(
            None,
            {ids_input_name: input_ids, hidden_input_name: encoder_hidden_states},
        )[0]
        step_probs = _softmax(logits[:, -1, :])
        next_tokens = np.argmax(step_probs, axis=1).astype(np.int64)
        next_column = next_tokens.copy()
        for row, next_token in enumerate(next_tokens.tolist()):
            if finished[row]:
                next_column[row] = pad_after_finish_id
                continue
            next_prob = float(step_probs[row, next_token])
            if eos_id is not None and next_token == eos_id:
                finished[row] = True
                next_column[row] = pad_after_finish_id
                continue
            token_ids[row].append(int(next_token))
            token_scores[row].append(next_prob)
        if bool(np.all(finished)):
            break
        input_ids = np.concatenate(
            [input_ids, next_column.reshape(batch_size, 1)],
            axis=1,
        )

    results: list[tuple[str, float]] = []
    for ids, scores in zip(token_ids, token_scores):
        text = tokenizer.decode(ids, skip_special_tokens=True).strip()
        score = float(sum(scores) / len(scores)) if scores else 0.0
        results.append((text, score))
    return results
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from rapidocr.ch_ppocr_det.utils import DBPostProcess, DetPreProcess
|
|
9
|
+
|
|
10
|
+
from .common import create_session
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _find_detector_model(root: Path) -> Path:
    """Return the first ``*det*.onnx`` file, by sorted path, anywhere under ``root``.

    Raises:
        FileNotFoundError: if no such file exists.
    """
    for candidate in sorted(root.glob("**/*det*.onnx")):
        return candidate
    raise FileNotFoundError(f"missing text detector model under {root}")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def warmup_text_detector(model_dir: str | Path, provider_info) -> None:
    """Create (and cache) the ONNX session for the text-detection model.

    Raises:
        FileNotFoundError: if no detector model exists under ``model_dir``.
    """
    detector_path = _find_detector_model(Path(model_dir))
    create_session(detector_path, provider_info)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _limit_side_len(image: np.ndarray) -> int:
    """Return the image's longest side, capped at 960 pixels."""
    height, width = image.shape[0], image.shape[1]
    longest = height if height >= width else width
    return 960 if longest > 960 else longest
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def detect_text_boxes(
    image_bgr: np.ndarray,
    model_dir: str | Path,
    provider_info,
) -> tuple[np.ndarray, tuple[float, ...]]:
    """Detect text regions with the PP-OCR DB detector found under ``model_dir``.

    The image is pre-processed by RapidOCR's ``DetPreProcess``. Its longest
    side is limited to the value from ``_limit_side_len``, and it is
    normalised with mean/std 0.5. The probability map is then decoded by
    ``DBPostProcess`` back to the original image size.

    Returns:
        ``(boxes, scores)``: the boxes as produced by ``DBPostProcess`` and one
        float score per box.
    """
    height, width = image_bgr.shape[0], image_bgr.shape[1]
    session = create_session(_find_detector_model(Path(model_dir)), provider_info)
    preprocess = DetPreProcess(
        limit_side_len=_limit_side_len(image_bgr),
        limit_type="max",
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
    )
    postprocess = DBPostProcess(
        thresh=0.3,
        box_thresh=0.5,
        max_candidates=1000,
        unclip_ratio=1.6,
        use_dilation=True,
    )
    tensor = preprocess(image_bgr)
    input_name = session.get_inputs()[0].name
    raw_outputs = session.run(None, {input_name: tensor})
    boxes, box_scores = postprocess(raw_outputs[0], (height, width))
    return boxes, tuple(float(value) for value in box_scores)
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from functools import lru_cache
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from rapidocr import EngineType, LangRec, ModelType, OCRVersion
|
|
10
|
+
from rapidocr.ch_ppocr_rec import TextRecInput, TextRecognizer
|
|
11
|
+
from rapidocr.utils.typings import TaskType
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class _Config(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    Nested ``dict`` values passed at construction are recursively wrapped in
    ``_Config`` so dotted access works at every level.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        source = dict(*args, **kwargs)
        self.update(
            (key, _Config(value) if isinstance(value, dict) else value)
            for key, value in source.items()
        )

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; map to the key.
        try:
            value = self[name]
        except KeyError as missing:
            raise AttributeError(name) from missing
        return value

    def __setattr__(self, name, value):
        self.__setitem__(name, value)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def warmup_pp_text_recognizer(model_dir: str | Path, provider_info) -> None:
    """Instantiate (and cache) the PP-OCR text recognizer for ``model_dir``.

    NOTE(review): the recognizer is a shared cached instance, so the batch
    size of 1 set here persists until ``recognize_pp_text_lines`` overwrites it.
    """
    warmed = _create_pp_text_recognizer(Path(model_dir), provider_info)
    warmed.rec_batch_num = 1
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def recognize_pp_text_lines(
|
|
39
|
+
images_bgr: list[np.ndarray],
|
|
40
|
+
model_dir: str | Path,
|
|
41
|
+
provider_info,
|
|
42
|
+
*,
|
|
43
|
+
rec_batch_num: int | None = None,
|
|
44
|
+
) -> list[tuple[str, float]]:
|
|
45
|
+
if not images_bgr:
|
|
46
|
+
return []
|
|
47
|
+
recognizer = _create_pp_text_recognizer(Path(model_dir), provider_info)
|
|
48
|
+
max_batch = max(1, int(rec_batch_num or 6))
|
|
49
|
+
recognizer.rec_batch_num = min(max(len(images_bgr), 1), max_batch)
|
|
50
|
+
rec_input = TextRecInput(img=images_bgr, return_word_box=False)
|
|
51
|
+
output = recognizer(rec_input)
|
|
52
|
+
return [(str(text), float(score)) for text, score in zip(output.txts, output.scores)]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _create_pp_text_recognizer(model_dir: Path, provider_info) -> TextRecognizer:
    """Return the cached recognizer keyed on (absolute model dir, CUDA flag).

    CUDA is enabled when ``provider_info.device`` equals ``"gpu"``. A missing
    attribute counts as not-GPU.
    """
    absolute_dir = str(model_dir.resolve())
    wants_gpu = bool(getattr(provider_info, "device", "") == "gpu")
    return _create_pp_text_recognizer_cached(absolute_dir, wants_gpu)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@lru_cache(maxsize=8)
def _create_pp_text_recognizer_cached(model_dir: str, use_cuda: bool) -> TextRecognizer:
    """Build a RapidOCR ``TextRecognizer`` for the PP-OCR model in ``model_dir``.

    The weights are the first ``*rec*.onnx`` file, by sorted path, anywhere
    under ``model_dir``. The vocabulary comes from ``_find_pp_vocab``. The
    server/mobile and v5/v4 flags are inferred from substrings of the model
    file name and directory name. Results are memoised per
    ``(model_dir, use_cuda)``.

    Raises:
        FileNotFoundError: if the recognizer model or vocabulary is missing.
    """
    model_dir = Path(model_dir)
    rec_models = sorted(model_dir.glob("**/*rec*.onnx"))
    if not rec_models:
        raise FileNotFoundError(f"no PP-OCR recognizer onnx file found under {model_dir}")
    model_path = rec_models[0]
    dict_path = _find_pp_vocab(model_dir)
    if dict_path is None:
        raise FileNotFoundError(f"missing PP-OCR vocabulary under {model_dir}")

    name_hints = (model_path.name, model_dir.name)
    is_server = any("server" in hint for hint in name_hints)
    is_v5 = any("v5" in hint for hint in name_hints)
    # NOTE(review): the vocabulary lookup in this module only returns
    # ppocrv5_keys.txt, so this flag is currently always False. Confirm intent.
    is_english = dict_path.name == "en_dict.txt"

    cuda_ep_cfg = {
        "device_id": 0,
        "arena_extend_strategy": "kNextPowerOfTwo",
        "cudnn_conv_algo_search": "EXHAUSTIVE",
        "do_copy_in_default_stream": True,
    }
    cann_ep_cfg = {
        "device_id": 0,
        "arena_extend_strategy": "kNextPowerOfTwo",
        "npu_mem_limit": 21474836480,  # 20 GiB
        "op_select_impl_mode": "high_performance",
        "optypelist_for_implmode": "Gelu",
        "enable_cann_graph": True,
    }
    engine_cfg = {
        "intra_op_num_threads": -1,
        "inter_op_num_threads": -1,
        "enable_cpu_mem_arena": False,
        "cpu_ep_cfg": {"arena_extend_strategy": "kSameAsRequested"},
        "use_cuda": use_cuda,
        "cuda_ep_cfg": cuda_ep_cfg,
        "use_dml": False,
        "dm_ep_cfg": None,
        "use_cann": False,
        "cann_ep_cfg": cann_ep_cfg,
    }
    settings = {
        "engine_type": EngineType.ONNXRUNTIME,
        "lang_type": LangRec.EN if is_english else LangRec.CH,
        "model_type": ModelType.SERVER if is_server else ModelType.MOBILE,
        "ocr_version": OCRVersion.PPOCRV5 if is_v5 else OCRVersion.PPOCRV4,
        "task_type": TaskType.REC,
        "model_path": str(model_path),
        "model_dir": None,
        "rec_keys_path": str(dict_path),
        "rec_img_shape": [3, 48, 320],
        "rec_batch_num": 6,
        "font_path": None,
        "engine_cfg": engine_cfg,
    }
    # _Config wraps the nested dicts so RapidOCR can use attribute access.
    return TextRecognizer(_Config(settings))
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def clear_text_recognizer_cache() -> None:
    """Drop every memoised PP-OCR recognizer so the next request rebuilds it."""
    cached_builder = _create_pp_text_recognizer_cached
    cached_builder.cache_clear()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _find_pp_vocab(model_dir: Path) -> Path | None:
|
|
120
|
+
candidate = model_dir / "ppocrv5_keys.txt"
|
|
121
|
+
return candidate if candidate.is_file() else None
|
mathcraft_ocr/api.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
from .results import FormulaRecognitionResult, MathCraftBlock, MixedRecognitionResult, OCRRegion
|
|
4
|
+
from .runtime import MathCraftRuntime, WarmupComponentStatus, WarmupPlan
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"FormulaRecognitionResult",
|
|
8
|
+
"MathCraftBlock",
|
|
9
|
+
"MathCraftRuntime",
|
|
10
|
+
"MixedRecognitionResult",
|
|
11
|
+
"OCRRegion",
|
|
12
|
+
"WarmupComponentStatus",
|
|
13
|
+
"WarmupPlan",
|
|
14
|
+
]
|