mathcraft-ocr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,39 @@
1
+ # coding: utf-8
2
+
3
+ from __future__ import annotations
4
+
5
# Version string of the mathcraft_ocr package.
__version__ = "0.1.0"

# Public API of the package. Every name except ``__version__`` is resolved
# lazily by the module-level ``__getattr__``, so importing the package does not
# import the ``api``, ``doctor`` or ``errors`` submodules until one of these
# names is first accessed.
__all__ = [
    "DoctorReport",
    "FormulaRecognitionResult",
    "MathCraftBlock",
    "MathCraftError",
    "MathCraftRuntime",
    "MixedRecognitionResult",
    "OCRRegion",
    "__version__",
    "run_doctor",
]
18
+
19
+
20
# Maps each lazily exported public name to the submodule (relative to this
# package) that defines it.
_LAZY_EXPORTS = {
    "FormulaRecognitionResult": "api",
    "MathCraftBlock": "api",
    "MathCraftRuntime": "api",
    "MixedRecognitionResult": "api",
    "OCRRegion": "api",
    "DoctorReport": "doctor",
    "run_doctor": "doctor",
    "MathCraftError": "errors",
}


def __getattr__(name: str) -> object:
    """Resolve lazily exported package attributes (PEP 562 module hook).

    On first access the submodule that defines *name* is imported and the
    resolved object is stored in this module's globals, so later lookups are
    plain attribute reads and no longer go through this hook.

    Raises:
        AttributeError: if *name* is not one of the lazily exported names.
    """
    module_name = _LAZY_EXPORTS.get(name)
    if module_name is None:
        # Match the interpreter's own wording so tracebacks name the module.
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    import importlib

    module = importlib.import_module(f".{module_name}", __name__)
    value = getattr(module, name)
    globals()[name] = value
    return value
@@ -0,0 +1,6 @@
1
+ # coding: utf-8
2
+
3
+ from .cli import main
4
+
5
+
6
# Run the CLI entry point and use its return value as the process exit status.
exit_code = main()
raise SystemExit(exit_code)
@@ -0,0 +1,13 @@
1
+ # coding: utf-8
2
+
3
+ from .formula_detector import warmup_formula_detector
4
+ from .formula_recognizer import warmup_formula_recognizer
5
+ from .text_detector import warmup_text_detector
6
+ from .text_recognizer import warmup_pp_text_recognizer
7
+
8
# Warm-up helpers re-exported from the model submodules imported above.
__all__ = [
    "warmup_formula_detector",
    "warmup_formula_recognizer",
    "warmup_text_detector",
    "warmup_pp_text_recognizer",
]
@@ -0,0 +1,46 @@
1
+ # coding: utf-8
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+
9
+ from ..providers import GPU_PROVIDER_NAMES, ProviderInfo
10
+
11
+
12
def _ort():
    """Import ``onnxruntime`` on demand and return the module object."""
    ort_module = importlib.import_module("onnxruntime")
    return ort_module
14
+
15
+
16
def session_providers(provider_info: ProviderInfo) -> list[str]:
    """Choose the execution providers to request for an ONNX Runtime session.

    If CPU is unavailable, every available provider is returned unchanged.
    Otherwise, when the active provider is one of ``GPU_PROVIDER_NAMES``, it is
    listed first with CPU as the fallback. In all remaining cases only the CPU
    provider is requested.
    """
    cpu_provider = "CPUExecutionProvider"
    available = list(provider_info.available_providers)
    active = provider_info.active_provider
    if cpu_provider not in available:
        # No CPU fallback to pair with: hand over whatever is available.
        return available
    wants_gpu = bool(active) and active in GPU_PROVIDER_NAMES
    return [active, cpu_provider] if wants_gpu else [cpu_provider]
24
+
25
+
26
def create_session(model_path: str | Path, provider_info: ProviderInfo):
    """Return a memoised ONNX Runtime session for *model_path*.

    The path is resolved to an absolute string, so every spelling of the same
    file shares one cache entry in ``_create_session_cached``.

    Raises:
        RuntimeError: if a GPU provider was requested but the created session
            did not enable it.
    """
    resolved = str(Path(model_path).resolve())
    requested = tuple(session_providers(provider_info))
    session = _create_session_cached(resolved, requested)
    granted = list(session.get_providers() or [])
    wanted = provider_info.active_provider
    gpu_requested = bool(wanted) and wanted in GPU_PROVIDER_NAMES
    if gpu_requested and wanted not in granted:
        raise RuntimeError(
            f"requested ONNX GPU provider {wanted}, but session providers are {granted}"
        )
    return session
37
+
38
+
39
@lru_cache(maxsize=16)
def _create_session_cached(model_path: str, providers: tuple[str, ...]):
    """Build an ``InferenceSession``; memoised per ``(path, providers)``, 16 entries max."""
    return _ort().InferenceSession(model_path, providers=[*providers])
43
+
44
+
45
def clear_session_cache() -> None:
    """Forget every memoised ONNX Runtime session so the next request rebuilds it."""
    cached_factory = _create_session_cached
    cached_factory.cache_clear()
@@ -0,0 +1,131 @@
1
+ # coding: utf-8
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ from .common import create_session
12
+
13
+
14
@dataclass(frozen=True)
class FormulaBox:
    """One formula region produced by ``detect_formula_boxes``.

    ``box`` holds four ``(x, y)`` corner points in original-image pixel
    coordinates, in the order (x1, y1), (x2, y1), (x2, y2), (x1, y2), i.e.
    top-left, top-right, bottom-right, bottom-left for row-down image axes.
    ``score`` is the winning class score from the detector, and ``label`` is
    the class name (``"embedding"`` or ``"isolated"``), or the numeric class id
    as a string when it has no name.
    """

    box: tuple[
        tuple[float, float],
        tuple[float, float],
        tuple[float, float],
        tuple[float, float],
    ]
    score: float
    label: str
24
+
25
+
26
def warmup_formula_detector(model_dir: str | Path, provider_info) -> None:
    """Create, and thereby cache, the ONNX session of the formula detector.

    The first ``*mfd*.onnx`` file directly under *model_dir* (in sorted path
    order) is used.

    Raises:
        FileNotFoundError: if *model_dir* contains no ``*mfd*.onnx`` file.
    """
    search_root = Path(model_dir)
    model_file = min(search_root.glob("*mfd*.onnx"), default=None)
    if model_file is None:
        raise FileNotFoundError(f"no mfd onnx file found under {search_root}")
    create_session(model_file, provider_info)
32
+
33
+
34
def _letterbox(image: np.ndarray, target_size: int = 768) -> tuple[np.ndarray, float, tuple[float, float]]:
    """Fit *image* into a square canvas while keeping its aspect ratio.

    The image is rescaled so its longer side matches *target_size*. It is then
    centred on a ``target_size x target_size`` three-channel uint8 canvas
    filled with the value 114.

    Returns:
        ``(canvas, scale, (left, top))``, where *scale* maps original pixel
        coordinates to resized ones and ``(left, top)`` is the offset of the
        resized image inside the canvas.
    """
    height, width = image.shape[:2]
    scale = min(target_size / width, target_size / height)
    # Clamp to at least one pixel. For extreme aspect ratios the short side
    # can round to 0, and cv2.resize rejects an empty target size.
    new_w = max(1, int(round(width * scale)))
    new_h = max(1, int(round(height * scale)))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    canvas = np.full((target_size, target_size, 3), 114, dtype=np.uint8)
    # The -0.1 bias avoids banker's rounding on .5. When the total padding is
    # odd, the extra pixel goes to the right/bottom edge.
    left = int(round((target_size - new_w) / 2 - 0.1))
    top = int(round((target_size - new_h) / 2 - 0.1))
    canvas[top : top + new_h, left : left + new_w] = resized
    return canvas, scale, (float(left), float(top))
47
+
48
+
49
def _nms_xyxy(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float) -> list[int]:
    """Greedy non-maximum suppression over ``(x1, y1, x2, y2)`` boxes.

    Boxes are visited in descending score order. Each kept box removes every
    remaining candidate whose IoU with it exceeds *iou_threshold*. Pairs with
    zero union area are treated as IoU 0.

    Returns:
        Indices of the kept boxes, highest score first.
    """
    if len(boxes) == 0:
        return []
    left, top, right, bottom = (boxes[:, column] for column in range(4))
    box_areas = np.maximum(0.0, right - left) * np.maximum(0.0, bottom - top)
    pending = scores.argsort()[::-1]
    survivors: list[int] = []
    while pending.size:
        best = int(pending[0])
        survivors.append(best)
        others = pending[1:]
        if others.size == 0:
            break
        overlap_w = np.maximum(
            0.0, np.minimum(right[best], right[others]) - np.maximum(left[best], left[others])
        )
        overlap_h = np.maximum(
            0.0, np.minimum(bottom[best], bottom[others]) - np.maximum(top[best], top[others])
        )
        overlap = overlap_w * overlap_h
        union = box_areas[best] + box_areas[others] - overlap
        iou = np.divide(overlap, union, out=np.zeros_like(overlap), where=union > 0)
        pending = others[iou <= iou_threshold]
    return survivors
76
+
77
+
78
def detect_formula_boxes(
    image_rgb: np.ndarray,
    model_dir: str | Path,
    provider_info,
    *,
    confidence_threshold: float = 0.25,
    iou_threshold: float = 0.45,
    input_size: int = 768,
) -> tuple[FormulaBox, ...]:
    """Detect formula regions in *image_rgb* with the MFD ONNX model.

    The first ``*mfd*.onnx`` file directly under *model_dir* (sorted) is used.
    The image goes through these steps:

    1. Letterbox it to *input_size* and scale it to [0, 1].
    2. Run it through the model as a single NCHW float32 batch.
    3. Keep predictions whose best class score is at least
       *confidence_threshold*.
    4. Map boxes back to original-image pixels and clip them to the image.
    5. Deduplicate with class-agnostic NMS at *iou_threshold*.

    Returns:
        A tuple of ``FormulaBox``, highest score first. Boxes that end up with
        zero width or height after clipping are not returned.

    Raises:
        FileNotFoundError: if no ``*mfd*.onnx`` file exists under *model_dir*.
    """
    root = Path(model_dir)
    candidates = sorted(root.glob("*mfd*.onnx"))
    if not candidates:
        raise FileNotFoundError(f"no mfd onnx file found under {root}")
    session = create_session(candidates[0], provider_info)

    preprocessed, scale, (pad_x, pad_y) = _letterbox(image_rgb, input_size)
    # HWC uint8 -> NCHW float32 in [0, 1], made C-contiguous so the runtime
    # receives a dense buffer rather than a transposed memory layout.
    model_input = np.ascontiguousarray(
        preprocessed.astype(np.float32).transpose(2, 0, 1)[np.newaxis, ...] / 255.0
    )
    output = session.run(None, {session.get_inputs()[0].name: model_input})[0]
    # Assumed layout: (1, 4 + num_classes, num_anchors), i.e. per anchor
    # (cx, cy, w, h, class scores...). TODO confirm against the exported model.
    preds = np.asarray(output[0]).T
    if preds.size == 0 or preds.shape[1] < 6:
        return ()
    class_scores = preds[:, 4:]
    class_ids = np.argmax(class_scores, axis=1)
    scores = class_scores[np.arange(len(class_scores)), class_ids]
    mask = scores >= confidence_threshold
    if not np.any(mask):
        return ()
    xywh = preds[mask, :4]
    class_ids = class_ids[mask]
    scores = scores[mask]

    cx, cy, w, h = xywh[:, 0], xywh[:, 1], xywh[:, 2], xywh[:, 3]
    boxes = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
    # Undo the letterbox offset and scaling, then clip to the original image.
    boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad_x) / scale
    boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad_y) / scale
    height, width = image_rgb.shape[:2]
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

    # Predictions lying in the padding or outside the image collapse to zero
    # width or height after clipping and cannot be cropped, so drop them. A
    # zero-area box has IoU 0 with everything, so it never suppresses others.
    valid = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
    if not np.any(valid):
        return ()
    boxes, scores, class_ids = boxes[valid], scores[valid], class_ids[valid]

    labels = ("embedding", "isolated")
    results: list[FormulaBox] = []
    for index in _nms_xyxy(boxes, scores, iou_threshold):
        x1, y1, x2, y2 = boxes[index].tolist()
        class_id = int(class_ids[index])
        results.append(
            FormulaBox(
                box=((x1, y1), (x2, y1), (x2, y2), (x1, y2)),
                score=float(scores[index]),
                label=labels[class_id] if class_id < len(labels) else str(class_id),
            )
        )
    return tuple(results)
@@ -0,0 +1,151 @@
1
+ # coding: utf-8
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+ from .common import create_session
13
+
14
+
15
def _softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis.

    The per-row maximum is subtracted before exponentiating so large logits
    do not overflow.
    """
    stabilized = np.subtract(logits, np.max(logits, axis=-1, keepdims=True))
    numerators = np.exp(stabilized)
    return numerators / numerators.sum(axis=-1, keepdims=True)
19
+
20
+
21
@lru_cache(maxsize=8)
def _load_processor(model_dir: str):
    """Load the TrOCR processor (ViT image processor plus tokenizer) from *model_dir*.

    ``transformers`` is imported lazily here. Results are memoised per
    directory string, up to 8 entries.
    """
    from transformers import AutoTokenizer, TrOCRProcessor, ViTImageProcessor

    return TrOCRProcessor(
        image_processor=ViTImageProcessor.from_pretrained(model_dir, use_fast=False),
        tokenizer=AutoTokenizer.from_pretrained(model_dir, use_fast=True),
    )
28
+
29
+
30
def warmup_formula_recognizer(model_dir: str | Path, provider_info) -> None:
    """Create, and thereby cache, the encoder and decoder ONNX sessions.

    Both files are checked for existence before any session is built.

    Raises:
        FileNotFoundError: if ``encoder_model.onnx`` or ``decoder_model.onnx``
            is missing from *model_dir* (the encoder is checked first).
    """
    root = Path(model_dir)
    parts = {
        "encoder": root / "encoder_model.onnx",
        "decoder": root / "decoder_model.onnx",
    }
    for role, model_file in parts.items():
        if not model_file.is_file():
            raise FileNotFoundError(f"missing {role} model under {root}")
    for model_file in parts.values():
        create_session(model_file, provider_info)
40
+
41
+
42
def _load_generation_ids(model_dir: Path, tokenizer) -> tuple[int, int | None]:
    """Return ``(decoder_start_token_id, eos_token_id)`` for greedy decoding.

    Sources are consulted in priority order, and the first non-null value for
    each id wins:

    1. ``generation_config.json`` (top level, then its ``decoder`` section),
    2. ``config.json`` (top level, then its ``decoder`` section),
    3. the tokenizer's ``bos_token_id`` / ``eos_token_id``.

    An id given as a list (Hugging Face allows several EOS ids) is reduced to
    its first non-null entry. Missing, unreadable or malformed JSON files are
    skipped.

    Raises:
        ValueError: if no decoder start id can be found in any source.
    """

    def _scalar(value):
        # Generation configs may list several ids; use the first one.
        if isinstance(value, (list, tuple)):
            return next((item for item in value if item is not None), None)
        return value

    sections: list[dict] = []
    for filename in ("generation_config.json", "config.json"):
        path = model_dir / filename
        if not path.is_file():
            continue
        try:
            data = json.loads(path.read_text(encoding="utf-8-sig"))
        except (OSError, ValueError):
            # Best effort: skip unreadable or malformed files. ValueError
            # covers both JSONDecodeError and UnicodeDecodeError.
            continue
        if not isinstance(data, dict):
            continue
        sections.append(data)
        decoder_config = data.get("decoder")
        if isinstance(decoder_config, dict):
            sections.append(decoder_config)

    def _lookup(key: str):
        # An explicit JSON null must not erase a value found earlier, so only
        # non-null entries count.
        for section in sections:
            value = _scalar(section.get(key))
            if value is not None:
                return value
        return None

    decoder_start_id = _lookup("decoder_start_token_id")
    if decoder_start_id is None:
        decoder_start_id = tokenizer.bos_token_id
    if decoder_start_id is None:
        raise ValueError(f"missing decoder_start_token_id under {model_dir}")
    eos_id = _lookup("eos_token_id")
    if eos_id is None:
        eos_id = tokenizer.eos_token_id
    return int(decoder_start_id), int(eos_id) if eos_id is not None else None
68
+
69
+
70
def recognize_formula_image(
    image: Image.Image | np.ndarray,
    model_dir: str | Path,
    provider_info,
    *,
    max_new_tokens: int = 256,
) -> tuple[str, float]:
    """Recognize a single formula image; a thin wrapper over the batch API.

    Returns:
        The ``(text, score)`` pair that ``recognize_formula_images`` produces
        for a one-element batch.
    """
    batch_results = recognize_formula_images(
        images=[image],
        model_dir=model_dir,
        provider_info=provider_info,
        max_new_tokens=max_new_tokens,
    )
    return batch_results[0]
83
+
84
+
85
def _as_rgb_image(image: Image.Image | np.ndarray) -> Image.Image:
    """Return *image* as a PIL image in RGB mode.

    The image processor needs three channels. Grayscale and palette images
    are converted. Images with transparency are flattened onto a white
    background rather than having alpha dropped, because dropping alpha would
    expose whatever colour transparent pixels store (often black).
    """
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    if pil_image.mode == "RGB":
        return pil_image
    has_alpha = pil_image.mode in ("RGBA", "LA", "PA") or "transparency" in pil_image.info
    if has_alpha:
        rgba = pil_image.convert("RGBA")
        background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        return Image.alpha_composite(background, rgba).convert("RGB")
    return pil_image.convert("RGB")


def recognize_formula_images(
    images: list[Image.Image | np.ndarray],
    model_dir: str | Path,
    provider_info,
    *,
    max_new_tokens: int = 256,
) -> list[tuple[str, float]]:
    """Recognize a batch of formula images with the encoder/decoder ONNX model.

    Decoding is greedy. At every step, each unfinished sequence appends its
    most probable token until it emits EOS or *max_new_tokens* steps have run.
    No past key/value cache is passed, so the decoder re-reads the whole
    prefix every step.

    Returns:
        One ``(text, score)`` pair per input image, in input order. *score* is
        the mean probability of the emitted (non-EOS) tokens, or 0.0 when no
        token was emitted.
    """
    if not images:
        return []
    root = Path(model_dir)
    processor = _load_processor(str(root))
    encoder_session = create_session(root / "encoder_model.onnx", provider_info)
    decoder_session = create_session(root / "decoder_model.onnx", provider_info)

    pil_images = [_as_rgb_image(image) for image in images]
    features = processor(images=pil_images, return_tensors="np")
    pixel_values = np.asarray(features["pixel_values"], dtype=np.float32)

    encoder_input_name = encoder_session.get_inputs()[0].name
    encoder_hidden_states = encoder_session.run(
        None,
        {encoder_input_name: pixel_values},
    )[0]

    tokenizer = processor.tokenizer
    decoder_start_id, eos_id = _load_generation_ids(root, tokenizer)
    batch_size = len(pil_images)
    # Input names do not change between steps, so look them up once. The
    # decoder is assumed to take (input_ids, encoder_hidden_states) in that order.
    decoder_inputs_meta = decoder_session.get_inputs()
    ids_input_name = decoder_inputs_meta[0].name
    hidden_input_name = decoder_inputs_meta[1].name

    input_ids = np.full((batch_size, 1), decoder_start_id, dtype=np.int64)
    token_ids: list[list[int]] = [[] for _ in range(batch_size)]
    token_scores: list[list[float]] = [[] for _ in range(batch_size)]
    finished = np.zeros((batch_size,), dtype=bool)
    # Finished rows keep receiving a filler token so the batch stays rectangular.
    pad_after_finish_id = eos_id if eos_id is not None else decoder_start_id

    for _ in range(max_new_tokens):
        logits = decoder_session.run(
            None,
            {ids_input_name: input_ids, hidden_input_name: encoder_hidden_states},
        )[0]
        step_probs = _softmax(logits[:, -1, :])
        next_tokens = np.argmax(step_probs, axis=1).astype(np.int64)
        next_column = next_tokens.copy()
        for row, next_token in enumerate(next_tokens.tolist()):
            if finished[row]:
                next_column[row] = pad_after_finish_id
                continue
            if eos_id is not None and next_token == eos_id:
                finished[row] = True
                next_column[row] = pad_after_finish_id
                continue
            token_ids[row].append(int(next_token))
            token_scores[row].append(float(step_probs[row, next_token]))
        if bool(np.all(finished)):
            break
        input_ids = np.concatenate(
            [input_ids, next_column.reshape(batch_size, 1)],
            axis=1,
        )

    results: list[tuple[str, float]] = []
    for ids, scores in zip(token_ids, token_scores):
        text = tokenizer.decode(ids, skip_special_tokens=True).strip()
        score = float(sum(scores) / len(scores)) if scores else 0.0
        results.append((text, score))
    return results
@@ -0,0 +1,57 @@
1
+ # coding: utf-8
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ from rapidocr.ch_ppocr_det.utils import DBPostProcess, DetPreProcess
9
+
10
+ from .common import create_session
11
+
12
+
13
def _find_detector_model(root: Path) -> Path:
    """Return the first ``*det*.onnx`` file under *root* (recursive, sorted order).

    Raises:
        FileNotFoundError: if no such file exists.
    """
    first_match = min(root.glob("**/*det*.onnx"), default=None)
    if first_match is None:
        raise FileNotFoundError(f"missing text detector model under {root}")
    return first_match
18
+
19
+
20
def warmup_text_detector(model_dir: str | Path, provider_info) -> None:
    """Create, and thereby cache, the ONNX session of the text detector.

    Raises:
        FileNotFoundError: if no ``*det*.onnx`` file exists under *model_dir*.
    """
    model_path = _find_detector_model(Path(model_dir))
    create_session(model_path, provider_info)
23
+
24
+
25
def _limit_side_len(image: np.ndarray) -> int:
    """Return the image's longer side, capped at 960 pixels."""
    height, width = image.shape[:2]
    longest = height if height >= width else width
    return 960 if longest > 960 else longest
28
+
29
+
30
def detect_text_boxes(
    image_bgr: np.ndarray,
    model_dir: str | Path,
    provider_info,
) -> tuple[np.ndarray, tuple[float, ...]]:
    """Detect text regions in a BGR image with the DB text-detection ONNX model.

    Preprocessing uses rapidocr's ``DetPreProcess`` with the following settings:
    - ``limit_side_len`` from ``_limit_side_len``;
    - ``limit_type="max"``;
    - mean and std of 0.5.

    Postprocessing uses ``DBPostProcess`` with the following settings:
    - ``thresh=0.3``;
    - ``box_thresh=0.5``;
    - up to 1000 candidates;
    - ``unclip_ratio=1.6``;
    - dilation enabled.

    Returns:
        ``(boxes, scores)``: the boxes exactly as ``DBPostProcess`` returns
        them for the original image size, and their scores as Python floats.

    Raises:
        FileNotFoundError: if no ``*det*.onnx`` file exists under *model_dir*.
    """
    session = create_session(_find_detector_model(Path(model_dir)), provider_info)
    height, width = image_bgr.shape[0], image_bgr.shape[1]
    preprocess = DetPreProcess(
        limit_side_len=_limit_side_len(image_bgr),
        limit_type="max",
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
    )
    postprocess = DBPostProcess(
        thresh=0.3,
        box_thresh=0.5,
        max_candidates=1000,
        unclip_ratio=1.6,
        use_dilation=True,
    )
    batch = preprocess(image_bgr)
    feed = {session.get_inputs()[0].name: batch}
    probability_map = session.run(None, feed)[0]
    boxes, scores = postprocess(probability_map, (height, width))
    return boxes, tuple(map(float, scores))
@@ -0,0 +1,121 @@
1
+ # coding: utf-8
2
+
3
+ from __future__ import annotations
4
+
5
+ from functools import lru_cache
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ from rapidocr import EngineType, LangRec, ModelType, OCRVersion
10
+ from rapidocr.ch_ppocr_rec import TextRecInput, TextRecognizer
11
+ from rapidocr.utils.typings import TaskType
12
+
13
+
14
class _Config(dict):
    """Attribute-access dict passed to rapidocr's ``TextRecognizer`` as its config.

    Nested plain dicts are wrapped as ``_Config`` both at construction and on
    attribute assignment, so chained access such as ``cfg.engine_cfg.use_cuda``
    works consistently. Missing keys raise ``AttributeError`` (not
    ``KeyError``), so ``getattr(cfg, name, default)`` and ``hasattr`` behave
    normally.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        data = dict(*args, **kwargs)
        for key, value in data.items():
            if isinstance(value, dict):
                value = _Config(value)
            self[key] = value

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        # Wrap plain dicts the same way __init__ does. Values that are already
        # _Config are stored as-is, so shared references are kept.
        if isinstance(value, dict) and not isinstance(value, _Config):
            value = _Config(value)
        self[name] = value

    def __delattr__(self, name):
        # Attribute deletion mirrors attribute access: it removes the key.
        try:
            del self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc
31
+
32
+
33
def warmup_pp_text_recognizer(model_dir: str | Path, provider_info) -> None:
    """Build and cache the PP-OCR text recognizer for *model_dir*.

    The recognizer instance is memoised and shared with
    ``recognize_pp_text_lines``. This function therefore leaves its state
    (e.g. ``rec_batch_num``) untouched; ``recognize_pp_text_lines`` sets the
    batch size it needs on every call.

    Raises:
        FileNotFoundError: if the recognizer model or its vocabulary is missing.
    """
    _create_pp_text_recognizer(Path(model_dir), provider_info)
36
+
37
+
38
def recognize_pp_text_lines(
    images_bgr: list[np.ndarray],
    model_dir: str | Path,
    provider_info,
    *,
    rec_batch_num: int | None = None,
) -> list[tuple[str, float]]:
    """Recognize text in a list of BGR line crops with the PP-OCR recognizer.

    *rec_batch_num* caps the recognizer batch size. ``None`` or 0 means 6, and
    values below 1 are raised to 1. The batch size is further limited to the
    number of images.

    Returns:
        One ``(text, score)`` pair per result returned by the recognizer.
    """
    if not images_bgr:
        return []
    recognizer = _create_pp_text_recognizer(Path(model_dir), provider_info)
    batch_cap = int(rec_batch_num or 6)
    if batch_cap < 1:
        batch_cap = 1
    recognizer.rec_batch_num = min(len(images_bgr) or 1, batch_cap)
    output = recognizer(TextRecInput(img=images_bgr, return_word_box=False))
    return [(str(text), float(score)) for text, score in zip(output.txts, output.scores)]
53
+
54
+
55
def _create_pp_text_recognizer(model_dir: Path, provider_info) -> TextRecognizer:
    """Return the memoised recognizer, keyed on resolved directory and CUDA use.

    CUDA is enabled when ``provider_info.device`` equals ``"gpu"``.
    """
    resolved = str(model_dir.resolve())
    wants_cuda = getattr(provider_info, "device", "") == "gpu"
    return _create_pp_text_recognizer_cached(resolved, bool(wants_cuda))
59
+
60
+
61
@lru_cache(maxsize=8)
def _create_pp_text_recognizer_cached(model_dir: str, use_cuda: bool) -> TextRecognizer:
    """Build a rapidocr ``TextRecognizer`` for the model under *model_dir*.

    The first ``*rec*.onnx`` file found recursively (sorted order) is used,
    together with the vocabulary located by ``_find_pp_vocab``. Model type
    (server/mobile) and OCR version (v5/v4) are inferred from substrings of
    the model file name or the directory name. The language is English only
    when the vocabulary file is named ``en_dict.txt``.
    NOTE(review): ``_find_pp_vocab`` currently only returns
    ``ppocrv5_keys.txt``, so the language is always ``LangRec.CH``.

    Raises:
        FileNotFoundError: if no recognizer model or no vocabulary is found.
    """
    root = Path(model_dir)
    model_path = min(root.glob("**/*rec*.onnx"), default=None)
    if model_path is None:
        raise FileNotFoundError(f"no PP-OCR recognizer onnx file found under {root}")
    dict_path = _find_pp_vocab(root)
    if dict_path is None:
        raise FileNotFoundError(f"missing PP-OCR vocabulary under {root}")

    def _mentions(tag: str) -> bool:
        return tag in model_path.name or tag in root.name

    engine_cfg = {
        "intra_op_num_threads": -1,
        "inter_op_num_threads": -1,
        "enable_cpu_mem_arena": False,
        "cpu_ep_cfg": {"arena_extend_strategy": "kSameAsRequested"},
        "use_cuda": use_cuda,
        "cuda_ep_cfg": {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": True,
        },
        "use_dml": False,
        "dm_ep_cfg": None,
        "use_cann": False,
        "cann_ep_cfg": {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "npu_mem_limit": 21474836480,
            "op_select_impl_mode": "high_performance",
            "optypelist_for_implmode": "Gelu",
            "enable_cann_graph": True,
        },
    }
    config = _Config(
        engine_type=EngineType.ONNXRUNTIME,
        lang_type=LangRec.EN if dict_path.name == "en_dict.txt" else LangRec.CH,
        model_type=ModelType.SERVER if _mentions("server") else ModelType.MOBILE,
        ocr_version=OCRVersion.PPOCRV5 if _mentions("v5") else OCRVersion.PPOCRV4,
        task_type=TaskType.REC,
        model_path=str(model_path),
        model_dir=None,
        rec_keys_path=str(dict_path),
        rec_img_shape=[3, 48, 320],
        rec_batch_num=6,
        font_path=None,
        engine_cfg=engine_cfg,
    )
    return TextRecognizer(config)
113
+
114
+
115
def clear_text_recognizer_cache() -> None:
    """Forget every memoised PP-OCR recognizer so the next request rebuilds it."""
    cached_builder = _create_pp_text_recognizer_cached
    cached_builder.cache_clear()
117
+
118
+
119
def _find_pp_vocab(model_dir: Path) -> Path | None:
    """Return ``ppocrv5_keys.txt`` in *model_dir* if it is a regular file, else None."""
    vocab_path = model_dir.joinpath("ppocrv5_keys.txt")
    if not vocab_path.is_file():
        return None
    return vocab_path
mathcraft_ocr/api.py ADDED
@@ -0,0 +1,14 @@
1
+ # coding: utf-8
2
+
3
+ from .results import FormulaRecognitionResult, MathCraftBlock, MixedRecognitionResult, OCRRegion
4
+ from .runtime import MathCraftRuntime, WarmupComponentStatus, WarmupPlan
5
+
6
# Public names of this facade module, re-exported from ``results`` and ``runtime``.
__all__ = [
    "FormulaRecognitionResult",
    "MathCraftBlock",
    "MathCraftRuntime",
    "MixedRecognitionResult",
    "OCRRegion",
    "WarmupComponentStatus",
    "WarmupPlan",
]