doc-page-extractor 0.2.4 (cp310-cp310-macosx_15_0_arm64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of doc-page-extractor might be problematic.
- doc_page_extractor/__init__.py +16 -0
- doc_page_extractor/clipper.py +119 -0
- doc_page_extractor/downloader.py +16 -0
- doc_page_extractor/extractor.py +218 -0
- doc_page_extractor/latex.py +33 -0
- doc_page_extractor/layout_order.py +239 -0
- doc_page_extractor/layoutreader.py +126 -0
- doc_page_extractor/model.py +133 -0
- doc_page_extractor/ocr.py +196 -0
- doc_page_extractor/ocr_corrector.py +126 -0
- doc_page_extractor/onnxocr/__init__.py +1 -0
- doc_page_extractor/onnxocr/cls_postprocess.py +26 -0
- doc_page_extractor/onnxocr/db_postprocess.py +246 -0
- doc_page_extractor/onnxocr/imaug.py +32 -0
- doc_page_extractor/onnxocr/operators.py +187 -0
- doc_page_extractor/onnxocr/predict_base.py +57 -0
- doc_page_extractor/onnxocr/predict_cls.py +109 -0
- doc_page_extractor/onnxocr/predict_det.py +139 -0
- doc_page_extractor/onnxocr/predict_rec.py +344 -0
- doc_page_extractor/onnxocr/predict_system.py +97 -0
- doc_page_extractor/onnxocr/rec_postprocess.py +896 -0
- doc_page_extractor/onnxocr/utils.py +71 -0
- doc_page_extractor/overlap.py +167 -0
- doc_page_extractor/plot.py +93 -0
- doc_page_extractor/raw_optimizer.py +104 -0
- doc_page_extractor/rectangle.py +72 -0
- doc_page_extractor/rotation.py +158 -0
- doc_page_extractor/table.py +60 -0
- doc_page_extractor/types.py +68 -0
- doc_page_extractor/utils.py +32 -0
- doc_page_extractor-0.2.4.dist-info/LICENSE +661 -0
- doc_page_extractor-0.2.4.dist-info/METADATA +88 -0
- doc_page_extractor-0.2.4.dist-info/RECORD +34 -0
- doc_page_extractor-0.2.4.dist-info/WHEEL +4 -0
doc_page_extractor/layoutreader.py
@@ -0,0 +1,126 @@
# Copy from https://github.com/ppaanngggg/layoutreader/blob/main/v3/helpers.py
from collections import defaultdict
from typing import List, Dict

import torch
from transformers import LayoutLMv3ForTokenClassification

MAX_LEN = 510
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2


class DataCollator:
    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
        bbox = []
        labels = []
        input_ids = []
        attention_mask = []

        # clip bbox and labels to max length, build input_ids and attention_mask
        for feature in features:
            _bbox = feature["source_boxes"]
            if len(_bbox) > MAX_LEN:
                _bbox = _bbox[:MAX_LEN]
            _labels = feature["target_index"]
            if len(_labels) > MAX_LEN:
                _labels = _labels[:MAX_LEN]
            _input_ids = [UNK_TOKEN_ID] * len(_bbox)
            _attention_mask = [1] * len(_bbox)
            assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
            bbox.append(_bbox)
            labels.append(_labels)
            input_ids.append(_input_ids)
            attention_mask.append(_attention_mask)

        # add CLS and EOS tokens
        for i in range(len(bbox)):
            bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
            labels[i] = [-100] + labels[i] + [-100]
            input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
            attention_mask[i] = [1] + attention_mask[i] + [1]

        # padding to max length
        max_len = max(len(x) for x in bbox)
        for i in range(len(bbox)):
            bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
            labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
            input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
            attention_mask[i] = attention_mask[i] + [0] * (
                max_len - len(attention_mask[i])
            )

        ret = {
            "bbox": torch.tensor(bbox),
            "attention_mask": torch.tensor(attention_mask),
            "labels": torch.tensor(labels),
            "input_ids": torch.tensor(input_ids),
        }
        # set label > MAX_LEN to -100, because original labels may be > MAX_LEN
        ret["labels"][ret["labels"] > MAX_LEN] = -100
        # set label > 0 to label-1, because original labels are 1-indexed
        ret["labels"][ret["labels"] > 0] -= 1
        return ret


def boxes2inputs(boxes: List[List[float]]) -> Dict[str, torch.Tensor]:
    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
    attention_mask = [1] + [1] * len(boxes) + [1]
    return {
        "bbox": torch.tensor([bbox]),
        "attention_mask": torch.tensor([attention_mask]),
        "input_ids": torch.tensor([input_ids]),
    }


def prepare_inputs(
    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
) -> Dict[str, torch.Tensor]:
    ret = {}
    for k, v in inputs.items():
        v = v.to(model.device)
        if torch.is_floating_point(v):
            v = v.to(model.dtype)
        ret[k] = v
    return ret


def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """
    parse logits to orders

    :param logits: logits from model
    :param length: input length
    :return: orders
    """
    logits = logits[1 : length + 1, :length]
    orders = logits.argsort(descending=False).tolist()
    ret = [o.pop() for o in orders]
    while True:
        order_to_idxes = defaultdict(list)
        for idx, order in enumerate(ret):
            order_to_idxes[order].append(idx)
        # filter idxes len > 1
        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
        if not order_to_idxes:
            break
        # filter
        for order, idxes in order_to_idxes.items():
            # find original logits of idxes
            idxes_to_logit = {}
            for idx in idxes:
                idxes_to_logit[idx] = logits[idx, order]
            idxes_to_logit = sorted(
                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
            )
            # keep the highest logit as order, set others to next candidate
            for idx, _ in idxes_to_logit[1:]:
                ret[idx] = orders[idx].pop()

    return ret


def check_duplicate(a: List[int]) -> bool:
    return len(a) != len(set(a))
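Taken together, these helpers form the usual LayoutReader inference loop: pack layout boxes into LayoutLMv3 inputs, run the model, then decode the logits into a reading order. A minimal sketch of that loop, assuming boxes are already normalized to LayoutLMv3's 0-1000 coordinate range (the `hantian/layoutreader` checkpoint name is taken from model.py below; the box values are made up):

import torch
from transformers import LayoutLMv3ForTokenClassification

# made-up boxes, already scaled to 0-1000: (x1, y1, x2, y2)
boxes = [[50, 60, 400, 90], [50, 120, 400, 150], [450, 60, 900, 90]]

model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
inputs = prepare_inputs(boxes2inputs(boxes), model)
with torch.no_grad():
    logits = model(**inputs).logits.cpu().squeeze(0)

orders = parse_logits(logits, len(boxes))
assert not check_duplicate(orders)
print(orders)  # reading-order position assigned to each input box, e.g. [0, 2, 1]

Note how `parse_logits` resolves ties: whenever two boxes claim the same position, only the box with the higher logit keeps it, and the others fall back to their next-best candidate; the loop repeats until no position is claimed twice.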
doc_page_extractor/model.py
@@ -0,0 +1,133 @@
from os import PathLike
from time import sleep
from typing import cast, runtime_checkable, Protocol
from pathlib import Path
from threading import Lock
from huggingface_hub import hf_hub_download, snapshot_download, try_to_load_from_cache


_RETRY_TIMES = 6
_RETRY_SLEEP = 3.5

@runtime_checkable
class Model(Protocol):
    def get_onnx_ocr_path(self) -> Path:
        raise NotImplementedError()

    def get_yolo_path(self) -> Path:
        raise NotImplementedError()

    def get_layoutreader_path(self) -> Path:
        raise NotImplementedError()

    def get_struct_eqtable_path(self) -> Path:
        raise NotImplementedError()

    def get_latex_path(self) -> Path:
        raise NotImplementedError()

class HuggingfaceModel(Model):
    def __init__(self, model_cache_dir: PathLike):
        super().__init__()
        self._lock: Lock = Lock()
        self._model_cache_dir: Path = Path(model_cache_dir)

    def get_onnx_ocr_path(self) -> Path:
        return self._get_model_path(
            repo_id="moskize/OnnxOCR",
            filename="README.md",
            repo_type=None,
            is_snapshot=True,
            wanna_dir_path=True,
        )

    def get_yolo_path(self) -> Path:
        return self._get_model_path(
            repo_id="opendatalab/PDF-Extract-Kit-1.0",
            filename="models/Layout/YOLO/doclayout_yolo_ft.pt",
            repo_type=None,
            is_snapshot=False,
            wanna_dir_path=False,
        )

    def get_layoutreader_path(self) -> Path:
        return self._get_model_path(
            repo_id="hantian/layoutreader",
            filename="model.safetensors",
            repo_type=None,
            is_snapshot=True,
            wanna_dir_path=True,
        )

    def get_struct_eqtable_path(self) -> Path:
        return self._get_model_path(
            repo_id="U4R/StructTable-InternVL2-1B",
            filename="model.safetensors",
            repo_type=None,
            is_snapshot=True,
            wanna_dir_path=True,
        )

    def get_latex_path(self) -> Path:
        return self._get_model_path(
            repo_id="lukbl/LaTeX-OCR",
            filename="checkpoints/weights.pth",
            repo_type="space",
            is_snapshot=True,
            wanna_dir_path=True,
        )

    def _get_model_path(
        self,
        repo_id: str,
        filename: str,
        repo_type: str | None,
        is_snapshot: bool,
        wanna_dir_path: bool,
    ) -> Path:

        with self._lock:
            model_path = try_to_load_from_cache(
                repo_id=repo_id,
                filename=filename,
                repo_type=repo_type,
                cache_dir=self._model_cache_dir
            )
            if isinstance(model_path, str):
                model_path = Path(model_path)
                if wanna_dir_path:
                    for _ in Path(filename).parts:
                        model_path = model_path.parent

            else:
                # https://github.com/huggingface/huggingface_hub/issues/1542#issuecomment-1630465844
                latest_error: ConnectionError | None = None
                for i in range(_RETRY_TIMES + 1):
                    if latest_error is not None:
                        print(f"Retrying to download {repo_id} model, attempt {i + 1}/{_RETRY_TIMES}...")
                        sleep(_RETRY_SLEEP)
                    try:
                        if is_snapshot:
                            model_path = snapshot_download(
                                cache_dir=self._model_cache_dir,
                                repo_id=repo_id,
                                repo_type=repo_type,
                                resume_download=True,
                            )
                        else:
                            model_path = hf_hub_download(
                                cache_dir=self._model_cache_dir,
                                repo_id=repo_id,
                                repo_type=repo_type,
                                filename=filename,
                                resume_download=True,
                            )
                        latest_error = None
                    except ConnectionError as err:
                        latest_error = err

                if latest_error is not None:
                    raise latest_error
                model_path = Path(cast(PathLike, model_path))

            return model_path
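`Model` is a `Protocol`, so the extractor depends only on these five `get_*_path` methods; `HuggingfaceModel` is the bundled implementation, resolving each path from the local cache via `try_to_load_from_cache` and falling back to a retried download. A minimal usage sketch, with an arbitrary cache directory:

from pathlib import Path

# any writable directory works as the cache location
model = HuggingfaceModel(Path.home() / ".cache" / "doc-page-extractor")

yolo_weights = model.get_yolo_path()              # single file via hf_hub_download
layoutreader_dir = model.get_layoutreader_path()  # snapshot directory via snapshot_download

Note the `wanna_dir_path` trick: on a cache hit, `try_to_load_from_cache` returns the cached file's path, so walking up one `parent` per component of `filename` recovers the snapshot root directory.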
doc_page_extractor/ocr.py
@@ -0,0 +1,196 @@
import numpy as np
import cv2
import os

from typing import cast, Any, Iterable, Literal, Generator
from dataclasses import dataclass
from .onnxocr import TextSystem
from .types import OCRFragment
from .model import Model
from .rectangle import Rectangle
from .utils import is_space_text


_MODELS = (
    ("ppocrv4", "rec", "rec.onnx"),
    ("ppocrv4", "cls", "cls.onnx"),
    ("ppocrv4", "det", "det.onnx"),
    ("ch_ppocr_server_v2.0", "ppocr_keys_v1.txt"),
)

@dataclass
class _OONXParams:
    use_angle_cls: bool
    use_gpu: bool
    rec_image_shape: tuple[int, int, int]
    cls_image_shape: tuple[int, int, int]
    cls_batch_num: int
    cls_thresh: float
    label_list: list[str]

    det_algorithm: str
    det_limit_side_len: int
    det_limit_type: str
    det_db_thresh: float
    det_db_box_thresh: float
    det_db_unclip_ratio: float
    use_dilation: bool
    det_db_score_mode: str
    det_box_type: str
    rec_batch_num: int
    drop_score: float
    save_crop_res: bool
    rec_algorithm: str
    use_space_char: bool
    rec_model_dir: str
    cls_model_dir: str
    det_model_dir: str
    rec_char_dict_path: str

class OCR:
    def __init__(self, device: Literal["cpu", "cuda"], model: Model):
        self._device: Literal["cpu", "cuda"] = device
        self._model: Model = model
        self._text_system: TextSystem | None = None

    def search_fragments(self, image: np.ndarray) -> Generator[OCRFragment, None, None]:
        for box, res in self._ocr(image):
            text, rank = res
            if is_space_text(text):
                continue

            rect = Rectangle(
                lt=(box[0][0], box[0][1]),
                rt=(box[1][0], box[1][1]),
                rb=(box[2][0], box[2][1]),
                lb=(box[3][0], box[3][1]),
            )
            if not rect.is_valid or rect.area == 0.0:
                continue

            yield OCRFragment(
                order=0,
                text=text,
                rank=rank,
                rect=rect,
            )

    def _ocr(self, image: np.ndarray) -> Generator[tuple[list[list[float]], tuple[str, float]], None, None]:
        text_system = self._get_text_system()
        image = self._preprocess_image(image)
        dt_boxes, rec_res = text_system(image)

        for box, res in zip(
            cast(Iterable[Any], dt_boxes),
            cast(Iterable[Any], rec_res),
        ):
            yield box.tolist(), res

    def _get_text_system(self) -> TextSystem:
        if self._text_system is None:
            model_paths = self._make_model_paths()
            self._text_system = TextSystem(_OONXParams(
                use_angle_cls=True,
                use_gpu=(self._device != "cpu"),
                rec_image_shape=(3, 48, 320),
                cls_image_shape=(3, 48, 192),
                cls_batch_num=6,
                cls_thresh=0.9,
                label_list=["0", "180"],
                det_algorithm="DB",
                det_limit_side_len=960,
                det_limit_type="max",
                det_db_thresh=0.3,
                det_db_box_thresh=0.6,
                det_db_unclip_ratio=1.5,
                use_dilation=False,
                det_db_score_mode="fast",
                det_box_type="quad",
                rec_batch_num=6,
                drop_score=0.5,
                save_crop_res=False,
                rec_algorithm="SVTR_LCNet",
                use_space_char=True,
                rec_model_dir=model_paths[0],
                cls_model_dir=model_paths[1],
                det_model_dir=model_paths[2],
                rec_char_dict_path=model_paths[3],
            ))
        return self._text_system

    def _make_model_paths(self) -> list[str]:
        model_paths: list[str] = []
        model_dir = self._model.get_onnx_ocr_path()
        for model_path in _MODELS:
            file_name = os.path.join(*model_path)
            model_paths.append(str(model_dir / file_name))
        return model_paths

    def _preprocess_image(self, np_image: np.ndarray) -> np.ndarray:
        image = self._alpha_to_color(np_image, (255, 255, 255))
        # image = cv2.bitwise_not(image) # inv
        # image = self._binarize_img(image) # bin
        image = cv2.normalize(
            src=image,
            dst=np.zeros((image.shape[0], image.shape[1])),
            alpha=0,
            beta=255,
            norm_type=cv2.NORM_MINMAX,
        )
        if cv2.cuda.getCudaEnabledDeviceCount() > 0:
            gpu_frame = cv2.cuda.GpuMat()
            gpu_frame.upload(image)
            image = cv2.cuda.fastNlMeansDenoisingColored(
                src=gpu_frame,
                dst=None,
                h_luminance=10,
                photo_render=10,
                search_window=15,
                block_size=7,
            )
            image = gpu_frame.download()
        elif cv2.ocl.haveOpenCL():
            cv2.ocl.setUseOpenCL(True)
            gpu_frame = cv2.UMat(cast(Any, image))
            image = cv2.fastNlMeansDenoisingColored(
                src=gpu_frame,
                dst=None,
                h=10,
                hColor=10,
                templateWindowSize=7,
                searchWindowSize=15,
            )
            image = image.get()
        else:
            image = cv2.fastNlMeansDenoisingColored(
                src=image,
                dst=None,
                h=10,
                hColor=10,
                templateWindowSize=7,
                searchWindowSize=15,
            )

        # image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # image to gray
        return image

    def _alpha_to_color(self, image: np.ndarray, alpha_color: tuple[float, float, float]) -> np.ndarray:
        if len(image.shape) == 3 and image.shape[2] == 4:
            B, G, R, A = cv2.split(image)
            alpha = A / 255

            R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
            G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
            B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)

            image = cv2.merge((B, G, R))

        return image

    def _binarize_img(self, image: np.ndarray):
        if len(image.shape) == 3 and image.shape[2] == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # conversion to grayscale image
            # use cv2 threshold binarization
            _, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            image = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
        return image
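`OCR` builds its `TextSystem` lazily on first use and routes every image through `_preprocess_image` (alpha flattening, min-max normalization, then non-local-means denoising on CUDA, OpenCL, or plain CPU, whichever is available). A usage sketch; the import paths follow the module layout shown in this diff, and the file name and device are illustrative:

import numpy as np
from pathlib import Path
from PIL import Image
from doc_page_extractor.model import HuggingfaceModel
from doc_page_extractor.ocr import OCR

ocr = OCR(device="cpu", model=HuggingfaceModel(Path("./model-cache")))
page = np.array(Image.open("page.png").convert("RGB"))  # hypothetical input image

for fragment in ocr.search_fragments(page):
    # each OCRFragment carries the recognized text, a confidence score
    # ("rank"), and a quadrilateral Rectangle in image coordinates
    print(fragment.rank, fragment.text)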
doc_page_extractor/ocr_corrector.py
@@ -0,0 +1,126 @@
import numpy as np

from typing import cast, Iterable
from shapely.geometry import Polygon
from PIL.Image import new, Image, Resampling
from .types import Layout, OCRFragment
from .ocr import OCR
from .overlap import overlap_rate
from .rectangle import Point, Rectangle


_MIN_RATE = 0.5

def correct_fragments(ocr: OCR, source: Image, layout: Layout):
    x1, y1, x2, y2 = layout.rect.wrapper
    image: Image = source.crop((
        round(x1), round(y1),
        round(x2), round(y2),
    ))
    image, dx, dy, scale = _adjust_image(image)
    image_np = np.array(image)
    ocr_fragments = list(ocr.search_fragments(image_np))
    corrected_fragments: list[OCRFragment] = []

    for fragment in ocr_fragments:
        _apply_fragment(fragment.rect, layout, dx, dy, scale)

    matched_fragments, not_matched_fragments = _match_fragments(
        zone_rect=layout.rect,
        fragments1=layout.fragments,
        fragments2=ocr_fragments,
    )
    for fragment1, fragment2 in matched_fragments:
        if fragment1.rank > fragment2.rank:
            corrected_fragments.append(fragment1)
        else:
            corrected_fragments.append(fragment2)

    corrected_fragments.extend(not_matched_fragments)
    layout.fragments = corrected_fragments

def _adjust_image(image: Image) -> tuple[Image, int, int, float]:
    # after testing, adding white borders to images can reduce
    # the possibility of some text not being recognized
    border_size: int = 50
    adjusted_size: int = 1024 - 2 * border_size
    width, height = image.size
    core_width = float(max(adjusted_size, width))
    core_height = float(max(adjusted_size, height))

    scale_x = core_width / width
    scale_y = core_height / height
    scale = min(scale_x, scale_y)
    adjusted_width = width * scale
    adjusted_height = height * scale

    dx = (core_width - adjusted_width) / 2.0
    dy = (core_height - adjusted_height) / 2.0
    dx = round(dx) + border_size
    dy = round(dy) + border_size

    if scale != 1.0:
        width = round(width * scale)
        height = round(height * scale)
        image = image.resize((width, height), Resampling.BICUBIC)

    width = round(core_width) + 2 * border_size
    height = round(core_height) + 2 * border_size
    new_image = new("RGB", (width, height), (255, 255, 255))
    new_image.paste(image, (dx, dy))

    return new_image, dx, dy, scale

def _apply_fragment(rect: Rectangle, layout: Layout, dx: int, dy: int, scale: float):
    rect.lt = _apply_point(rect.lt, layout, dx, dy, scale)
    rect.lb = _apply_point(rect.lb, layout, dx, dy, scale)
    rect.rb = _apply_point(rect.rb, layout, dx, dy, scale)
    rect.rt = _apply_point(rect.rt, layout, dx, dy, scale)

def _apply_point(point: Point, layout: Layout, dx: int, dy: int, scale: float) -> Point:
    x, y = point
    x = (x - dx) / scale + layout.rect.lt[0]
    y = (y - dy) / scale + layout.rect.lt[1]
    return x, y

def _match_fragments(
    zone_rect: Rectangle,
    fragments1: Iterable[OCRFragment],
    fragments2: Iterable[OCRFragment],
) -> tuple[list[tuple[OCRFragment, OCRFragment]], list[OCRFragment]]:

    zone_polygon = Polygon(zone_rect)
    fragments2 = list(fragments2)
    matched_fragments: list[tuple[OCRFragment, OCRFragment]] = []
    not_matched_fragments: list[OCRFragment] = []

    for fragment1 in fragments1:
        polygon1 = Polygon(fragment1.rect)
        polygon1 = cast(Polygon, zone_polygon.intersection(polygon1))
        if polygon1.is_empty:
            continue

        beast_j = -1
        beast_rate = 0.0

        for j, fragment2 in enumerate(fragments2):
            polygon2 = Polygon(fragment2.rect)
            rate = overlap_rate(polygon1, polygon2)
            if rate < _MIN_RATE:
                continue

            if rate > beast_rate:
                beast_j = j
                beast_rate = rate

        if beast_j != -1:
            matched_fragments.append((
                fragment1,
                fragments2[beast_j],
            ))
            del fragments2[beast_j]
        else:
            not_matched_fragments.append(fragment1)

    not_matched_fragments.extend(fragments2)
    return matched_fragments, not_matched_fragments
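The geometry here is a simple affine round trip: `_adjust_image` scales the cropped layout by `scale` and pastes it at offset `(dx, dy)` inside a white-bordered canvas, and `_apply_point` inverts exactly that transform to map OCR boxes back into page coordinates. A worked check with made-up numbers:

# Suppose the layout crop starts at page x = 200, was scaled up by 2.0,
# and was pasted with horizontal offset dx = 62 (centering plus 50px border).
dx, scale, layout_left = 62, 2.0, 200.0

x_page = 300.0                               # a point on the original page
x_ocr = (x_page - layout_left) * scale + dx  # where it lands in the OCR image: 262.0

# _apply_point inverts the transform:
assert (x_ocr - dx) / scale + layout_left == x_page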
doc_page_extractor/onnxocr/__init__.py
@@ -0,0 +1 @@
from .predict_system import TextSystem
doc_page_extractor/onnxocr/cls_postprocess.py
@@ -0,0 +1,26 @@
class ClsPostProcess(object):
    """ Convert between text-label and text-index """

    def __init__(self, label_list=None, key=None, **kwargs):
        super(ClsPostProcess, self).__init__()
        self.label_list = label_list
        self.key = key

    def __call__(self, preds, label=None, *args, **kwargs):
        if self.key is not None:
            preds = preds[self.key]

        label_list = self.label_list
        if label_list is None:
            label_list = {idx: idx for idx in range(preds.shape[-1])}

        # if isinstance(preds, paddle.Tensor):
        #     preds = preds.numpy()

        pred_idxs = preds.argmax(axis=1)
        decode_out = [(label_list[idx], preds[i, idx])
                      for i, idx in enumerate(pred_idxs)]
        if label is None:
            return decode_out
        label = [(label_list[idx], 1.0) for idx in label]
        return decode_out, label
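`ClsPostProcess` decodes the angle classifier's softmax rows into `(label, score)` pairs; with the `label_list=["0", "180"]` passed in ocr.py, each row becomes a rotation decision. A small sketch with made-up scores:

import numpy as np

post = ClsPostProcess(label_list=["0", "180"])
preds = np.array([
    [0.15, 0.85],  # most likely rotated 180 degrees
    [0.95, 0.05],  # most likely upright
])
print(post(preds))  # [('180', 0.85), ('0', 0.95)] -- scores come back as numpy floats

In PaddleOCR-style pipelines, a "180" result scoring at or above `cls_thresh` (0.9 here) is what triggers rotating the crop before recognition.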