doc-page-extractor 0.2.0__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc_page_extractor/__init__.py +5 -15
- doc_page_extractor/check_env.py +40 -0
- doc_page_extractor/extractor.py +88 -215
- doc_page_extractor/model.py +97 -0
- doc_page_extractor/parser.py +51 -0
- doc_page_extractor/plot.py +52 -79
- doc_page_extractor/redacter.py +111 -0
- doc_page_extractor-1.0.2.dist-info/METADATA +120 -0
- doc_page_extractor-1.0.2.dist-info/RECORD +11 -0
- {doc_page_extractor-0.2.0.dist-info → doc_page_extractor-1.0.2.dist-info}/WHEEL +1 -2
- doc_page_extractor-1.0.2.dist-info/licenses/LICENSE +21 -0
- doc_page_extractor/clipper.py +0 -119
- doc_page_extractor/downloader.py +0 -16
- doc_page_extractor/latex.py +0 -31
- doc_page_extractor/layout_order.py +0 -237
- doc_page_extractor/layoutreader.py +0 -126
- doc_page_extractor/models.py +0 -92
- doc_page_extractor/ocr.py +0 -200
- doc_page_extractor/ocr_corrector.py +0 -126
- doc_page_extractor/onnxocr/__init__.py +0 -1
- doc_page_extractor/onnxocr/cls_postprocess.py +0 -26
- doc_page_extractor/onnxocr/db_postprocess.py +0 -246
- doc_page_extractor/onnxocr/imaug.py +0 -32
- doc_page_extractor/onnxocr/operators.py +0 -187
- doc_page_extractor/onnxocr/predict_base.py +0 -57
- doc_page_extractor/onnxocr/predict_cls.py +0 -109
- doc_page_extractor/onnxocr/predict_det.py +0 -139
- doc_page_extractor/onnxocr/predict_rec.py +0 -344
- doc_page_extractor/onnxocr/predict_system.py +0 -97
- doc_page_extractor/onnxocr/rec_postprocess.py +0 -896
- doc_page_extractor/onnxocr/utils.py +0 -71
- doc_page_extractor/overlap.py +0 -167
- doc_page_extractor/raw_optimizer.py +0 -104
- doc_page_extractor/rectangle.py +0 -72
- doc_page_extractor/rotation.py +0 -158
- doc_page_extractor/struct_eqtable/__init__.py +0 -49
- doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -2
- doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -394
- doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -198
- doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -81
- doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -3
- doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -76
- doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -1047
- doc_page_extractor/table.py +0 -70
- doc_page_extractor/types.py +0 -91
- doc_page_extractor/utils.py +0 -32
- doc_page_extractor-0.2.0.dist-info/METADATA +0 -85
- doc_page_extractor-0.2.0.dist-info/RECORD +0 -45
- doc_page_extractor-0.2.0.dist-info/licenses/LICENSE +0 -661
- doc_page_extractor-0.2.0.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -0
- tests/test_history_bus.py +0 -55
|
@@ -1,126 +0,0 @@
|
|
|
1
|
-
# Copy from https://github.com/ppaanngggg/layoutreader/blob/main/v3/helpers.py
|
|
2
|
-
from collections import defaultdict
|
|
3
|
-
from typing import List, Dict
|
|
4
|
-
|
|
5
|
-
import torch
|
|
6
|
-
from transformers import LayoutLMv3ForTokenClassification
|
|
7
|
-
|
|
8
|
-
# Upper bound on box tokens per sample. Two extra positions (a CLS token in
# front, an EOS token at the end) are added on top of this by the helpers below.
MAX_LEN: int = 510

# Token ids written into ``input_ids``: CLS opens a sequence, EOS closes (and
# pads) it, and every box position uses the UNK id.
CLS_TOKEN_ID: int = 0
UNK_TOKEN_ID: int = 3
EOS_TOKEN_ID: int = 2
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class DataCollator:
    """
    Collate LayoutReader samples into one padded batch of tensors.

    Each feature is a dict with ``"source_boxes"`` (one box per position) and
    ``"target_index"`` (one label per box; 1-indexed, per the original
    implementation's comment). Both are truncated to ``MAX_LEN``, wrapped with
    a CLS position in front and an EOS position at the end, then right-padded
    to the longest sample in the batch.
    """

    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
        empty_box = [0, 0, 0, 0]
        bbox: List[list] = []
        labels: List[list] = []
        input_ids: List[list] = []
        attention_mask: List[list] = []

        for feature in features:
            boxes = feature["source_boxes"]
            if len(boxes) > MAX_LEN:
                boxes = boxes[:MAX_LEN]
            targets = feature["target_index"]
            if len(targets) > MAX_LEN:
                targets = targets[:MAX_LEN]

            token_ids = [UNK_TOKEN_ID] * len(boxes)
            mask = [1] * len(boxes)
            assert len(boxes) == len(targets) == len(token_ids) == len(mask)

            # Surround every sample with a CLS slot in front and an EOS slot at
            # the end; those slots carry an empty box and an ignored label.
            bbox.append([empty_box] + boxes + [empty_box])
            labels.append([-100] + targets + [-100])
            input_ids.append([CLS_TOKEN_ID] + token_ids + [EOS_TOKEN_ID])
            attention_mask.append([1] + mask + [1])

        # Right-pad every sample to the longest one; padding is masked out.
        longest = max(len(row) for row in bbox)
        bbox = [row + [empty_box] * (longest - len(row)) for row in bbox]
        labels = [row + [-100] * (longest - len(row)) for row in labels]
        input_ids = [row + [EOS_TOKEN_ID] * (longest - len(row)) for row in input_ids]
        attention_mask = [row + [0] * (longest - len(row)) for row in attention_mask]

        batch = {
            "bbox": torch.tensor(bbox),
            "attention_mask": torch.tensor(attention_mask),
            "labels": torch.tensor(labels),
            "input_ids": torch.tensor(input_ids),
        }
        label_tensor = batch["labels"]
        # Targets beyond the truncation limit point past the kept positions.
        label_tensor[label_tensor > MAX_LEN] = -100
        # Shift 1-indexed targets to 0-indexed; -100 entries are unaffected.
        label_tensor[label_tensor > 0] -= 1
        return batch
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """
    Build a single-sample (batch size 1) model input from a list of boxes.

    The boxes are wrapped with an empty CLS box in front and an empty EOS box at
    the end. Every position is attended to, and each box position gets the UNK
    token id.
    """
    count = len(boxes)
    empty_box = [0, 0, 0, 0]
    framed_boxes = [empty_box] + boxes + [empty_box]
    token_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * count + [EOS_TOKEN_ID]
    mask = [1] + [1] * count + [1]
    return {
        "bbox": torch.tensor([framed_boxes]),
        "attention_mask": torch.tensor([mask]),
        "input_ids": torch.tensor([token_ids]),
    }
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
def prepare_inputs(
    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
) -> Dict[str, torch.Tensor]:
    """
    Move every input tensor to ``model.device``.

    Floating-point tensors are also cast to ``model.dtype``. Integer tensors
    keep their dtype.
    """

    def _place(tensor: torch.Tensor) -> torch.Tensor:
        tensor = tensor.to(model.device)
        if not torch.is_floating_point(tensor):
            return tensor
        return tensor.to(model.dtype)

    return {name: _place(tensor) for name, tensor in inputs.items()}
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """
    Decode model logits into one reading-order position per input box.

    :param logits: logits from the model; row 0 is skipped (the CLS slot) and
        only the first ``length`` columns are considered
    :param length: number of real input boxes
    :return: for each box, the position it was assigned (unique when the
        conflict resolution terminates)

    Each box first takes its highest-scoring position. When several boxes
    claim the same position, the one with the highest logit for it keeps it;
    the others fall back to their next-best candidate. This repeats until no
    position is claimed twice.
    """
    scores = logits[1 : length + 1, :length]
    # Per row, candidate positions in ascending score order; the tail is best.
    candidates = scores.argsort(descending=False).tolist()
    result = [row.pop() for row in candidates]

    while True:
        claimants = defaultdict(list)
        for box_idx, position in enumerate(result):
            claimants[position].append(box_idx)
        conflicts = {pos: boxes for pos, boxes in claimants.items() if len(boxes) > 1}
        if not conflicts:
            break

        for position, box_idxes in conflicts.items():
            # Stable sort: among equal logits the earlier box keeps the position.
            ranked = sorted(
                ((box_idx, scores[box_idx, position]) for box_idx in box_idxes),
                key=lambda pair: pair[1],
                reverse=True,
            )
            for box_idx, _ in ranked[1:]:
                result[box_idx] = candidates[box_idx].pop()

    return result
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
def check_duplicate(a: List[int]) -> bool:
    """Return True when some value occurs more than once in ``a``."""
    unique_count = len(set(a))
    return unique_count < len(a)
|
doc_page_extractor/models.py
DELETED
|
@@ -1,92 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
|
|
3
|
-
from logging import Logger
|
|
4
|
-
from huggingface_hub import hf_hub_download, snapshot_download, try_to_load_from_cache
|
|
5
|
-
from .types import ModelsDownloader
|
|
6
|
-
|
|
7
|
-
class HuggingfaceModelsDownloader(ModelsDownloader):
    """
    Resolve local paths to the models the extractor needs, downloading them
    from the Hugging Face Hub when they are not cached yet.

    Each lookup first probes the local cache for one representative file with
    ``try_to_load_from_cache``. It falls back to a network download only when
    that probe does not return a cached path.
    """

    def __init__(
        self,
        logger: Logger,
        model_dir_path: str | None
    ):
        # Receives an info message whenever a download is started.
        self._logger = logger
        # cache_dir passed to huggingface_hub; None lets the library pick its default.
        self._model_dir_path: str | None = model_dir_path

    def onnx_ocr(self) -> str:
        """Return the local snapshot directory of the ``moskize/OnnxOCR`` repo."""
        return self._snapshot_dir(
            repo_id="moskize/OnnxOCR",
            probe_filename="README.md",
            model_name="OCR",
        )

    def yolo(self) -> str:
        """Return the local path of the DocLayout-YOLO weights file (a single file, not a directory)."""
        yolo_file_path = try_to_load_from_cache(
            repo_id="opendatalab/PDF-Extract-Kit-1.0",
            filename="models/Layout/YOLO/doclayout_yolo_ft.pt",
            cache_dir=self._model_dir_path
        )
        if isinstance(yolo_file_path, str):
            return yolo_file_path
        self._logger.info("Downloading YOLO model...")
        return hf_hub_download(
            cache_dir=self._model_dir_path,
            repo_id="opendatalab/PDF-Extract-Kit-1.0",
            filename="models/Layout/YOLO/doclayout_yolo_ft.pt",
        )

    def layoutreader(self) -> str:
        """Return the local snapshot directory of the ``hantian/layoutreader`` repo."""
        return self._snapshot_dir(
            repo_id="hantian/layoutreader",
            probe_filename="model.safetensors",
            model_name="LayoutReader",
        )

    def struct_eqtable(self) -> str:
        """Return the local snapshot directory of the ``U4R/StructTable-InternVL2-1B`` repo."""
        return self._snapshot_dir(
            repo_id="U4R/StructTable-InternVL2-1B",
            probe_filename="model.safetensors",
            model_name="StructEqTable",
        )

    def latex(self) -> str:
        """Return the local snapshot directory of the ``lukbl/LaTeX-OCR`` Space."""
        return self._snapshot_dir(
            repo_id="lukbl/LaTeX-OCR",
            probe_filename="checkpoints/weights.pth",
            model_name="LaTeX",
            repo_type="space",
        )

    def _snapshot_dir(
        self,
        repo_id: str,
        probe_filename: str,
        model_name: str,
        repo_type: str | None = None,
    ) -> str:
        """
        Return the snapshot directory of ``repo_id``, downloading it if needed.

        :param repo_id: Hub repository id.
        :param probe_filename: repo-relative file (``/``-separated) used to
            detect an existing cached snapshot.
        :param model_name: human-readable name used in the download log message.
        :param repo_type: Hub repository type; ``None`` means the library
            default (a model repo).
        """
        cached = try_to_load_from_cache(
            repo_id=repo_id,
            filename=probe_filename,
            repo_type=repo_type,
            cache_dir=self._model_dir_path,
        )
        if isinstance(cached, str):
            # Climb one directory per path component of the probe file to get
            # back from the cached file to the snapshot root.
            snapshot_dir = cached
            for _ in probe_filename.split("/"):
                snapshot_dir = os.path.dirname(snapshot_dir)
            return snapshot_dir

        self._logger.info("Downloading %s model...", model_name)
        return snapshot_download(
            cache_dir=self._model_dir_path,
            repo_type=repo_type,
            repo_id=repo_id,
        )
|
doc_page_extractor/ocr.py
DELETED
|
@@ -1,200 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import cv2
|
|
3
|
-
import os
|
|
4
|
-
|
|
5
|
-
from typing import Literal, Generator
|
|
6
|
-
from dataclasses import dataclass
|
|
7
|
-
from .onnxocr import TextSystem
|
|
8
|
-
from .types import GetModelDir, OCRFragment
|
|
9
|
-
from .rectangle import Rectangle
|
|
10
|
-
from .utils import is_space_text
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
# OCR model files, as path components relative to the model directory.
# Order matters: OCR._get_text_system reads the resolved paths by position
# (0 = recognizer, 1 = angle classifier, 2 = detector, 3 = character dictionary).
_MODELS = (
    ("ppocrv4", "rec", "rec.onnx"),  # text recognition model
    ("ppocrv4", "cls", "cls.onnx"),  # text-angle classification model
    ("ppocrv4", "det", "det.onnx"),  # text detection model
    ("ch_ppocr_server_v2.0", "ppocr_keys_v1.txt"),  # recognizer character dictionary
)
|
|
19
|
-
|
|
20
|
-
@dataclass
class _OONXParams:
    # Configuration object handed to onnxocr's TextSystem. The field order is
    # kept as-is because it defines the generated __init__ signature.

    # --- angle classification ---
    use_angle_cls: bool
    use_gpu: bool
    rec_image_shape: tuple[int, int, int]
    cls_image_shape: tuple[int, int, int]
    cls_batch_num: int
    cls_thresh: float
    label_list: list[str]

    # --- text detection ---
    det_algorithm: str
    det_limit_side_len: int
    det_limit_type: str
    det_db_thresh: float
    det_db_box_thresh: float
    det_db_unclip_ratio: float
    use_dilation: bool
    det_db_score_mode: str
    det_box_type: str

    # --- text recognition ---
    rec_batch_num: int
    drop_score: float
    save_crop_res: bool
    rec_algorithm: str
    use_space_char: bool

    # --- model file locations ---
    rec_model_dir: str
    cls_model_dir: str
    det_model_dir: str
    rec_char_dict_path: str
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
class OCR:
    """
    Wrapper around the ONNX PP-OCR ``TextSystem`` that yields the text
    fragments found in an image.

    The ``TextSystem`` is built lazily on first use. Its model files are
    resolved under the directory returned by ``get_model_dir``.
    """

    def __init__(
        self,
        device: Literal["cpu", "cuda"],
        get_model_dir: GetModelDir,
    ):
        # Any value other than "cpu" requests GPU inference (see _get_text_system).
        self._device: Literal["cpu", "cuda"] = device
        # Called lazily when the text system is first built; returns the model root directory.
        self._get_model_dir: GetModelDir = get_model_dir
        # Built on first use and reused afterwards.
        self._text_system: TextSystem | None = None

    def search_fragments(self, image: np.ndarray) -> Generator[OCRFragment, None, None]:
        """
        Run OCR on ``image`` and yield one ``OCRFragment`` per recognized box.

        The following boxes are skipped:
        - boxes whose text is whitespace-only according to ``is_space_text``;
        - boxes whose quadrilateral is invalid;
        - boxes whose quadrilateral has zero area.

        Every yielded fragment has ``order=0``. ``rank`` is the second element of
        the recognizer result, presumably a confidence score (TODO confirm in
        onnxocr).
        """
        for box, res in self._ocr(image):
            text, rank = res
            if is_space_text(text):
                continue

            rect = Rectangle(
                lt=(box[0][0], box[0][1]),
                rt=(box[1][0], box[1][1]),
                rb=(box[2][0], box[2][1]),
                lb=(box[3][0], box[3][1]),
            )
            if not rect.is_valid or rect.area == 0.0:
                continue

            yield OCRFragment(
                order=0,
                text=text,
                rank=rank,
                rect=rect,
            )

    def _ocr(self, image: np.ndarray) -> Generator[tuple[list[list[float]], tuple[str, float]], None, None]:
        """Preprocess ``image``, run the text system, and yield ``(box corners, (text, score))`` pairs."""
        text_system = self._get_text_system()
        image = self._preprocess_image(image)
        dt_boxes, rec_res = text_system(image)

        for box, res in zip(dt_boxes, rec_res):
            yield box.tolist(), res

    def make_model_paths(self) -> list[str]:
        """Return the paths of the files listed in ``_MODELS``, joined under the model directory, in the same order."""
        model_paths = []
        model_dir = self._get_model_dir()
        for model_path in _MODELS:
            file_name = os.path.join(*model_path)
            model_paths.append(os.path.join(model_dir, file_name))
        return model_paths

    def _get_text_system(self) -> TextSystem:
        """Build the ``TextSystem`` on first call and return the cached instance afterwards."""
        if self._text_system is None:
            # Indices follow the order of _MODELS: rec, cls, det, character dictionary.
            model_paths = self.make_model_paths()
            self._text_system = TextSystem(_OONXParams(
                use_angle_cls=True,
                use_gpu=(self._device != "cpu"),
                rec_image_shape=(3, 48, 320),
                cls_image_shape=(3, 48, 192),
                cls_batch_num=6,
                cls_thresh=0.9,
                label_list=["0", "180"],
                det_algorithm="DB",
                det_limit_side_len=960,
                det_limit_type="max",
                det_db_thresh=0.3,
                det_db_box_thresh=0.6,
                det_db_unclip_ratio=1.5,
                use_dilation=False,
                det_db_score_mode="fast",
                det_box_type="quad",
                rec_batch_num=6,
                drop_score=0.5,
                save_crop_res=False,
                rec_algorithm="SVTR_LCNet",
                use_space_char=True,
                rec_model_dir=model_paths[0],
                cls_model_dir=model_paths[1],
                det_model_dir=model_paths[2],
                rec_char_dict_path=model_paths[3],
            ))

        return self._text_system

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Prepare an image for OCR. The steps are:

        1. Flatten any alpha channel onto white.
        2. Promote single-channel images to 3 channels.
        3. Min-max normalize pixel values to the 0..255 range.
        4. Denoise with non-local means.

        Denoising runs on the CUDA OpenCV backend when a CUDA device is
        available, otherwise via OpenCL when available, and otherwise on the CPU.
        """
        image = self._alpha_to_color(image, (255, 255, 255))

        # fastNlMeansDenoisingColored requires 3-channel input, so promote
        # grayscale images instead of failing on them.
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

        # dst=None lets OpenCV allocate an output of the source type. A
        # preallocated array of a different shape/type would be reallocated anyway.
        image = cv2.normalize(
            src=image,
            dst=None,
            alpha=0,
            beta=255,
            norm_type=cv2.NORM_MINMAX,
        )

        if cv2.cuda.getCudaEnabledDeviceCount() > 0:
            gpu_frame = cv2.cuda.GpuMat()
            gpu_frame.upload(image)
            denoised = cv2.cuda.fastNlMeansDenoisingColored(
                src=gpu_frame,
                dst=None,
                h_luminance=10,
                photo_render=10,
                search_window=15,
                block_size=7,
            )
            # Download the denoised result, not the untouched input frame.
            image = denoised.download()
        elif cv2.ocl.haveOpenCL():
            # NOTE: this flips OpenCV's process-wide OpenCL switch.
            cv2.ocl.setUseOpenCL(True)
            gpu_frame = cv2.UMat(image)
            image = cv2.fastNlMeansDenoisingColored(
                src=gpu_frame,
                dst=None,
                h=10,
                hColor=10,
                templateWindowSize=7,
                searchWindowSize=15,
            )
            image = image.get()
        else:
            image = cv2.fastNlMeansDenoisingColored(
                src=image,
                dst=None,
                h=10,
                hColor=10,
                templateWindowSize=7,
                searchWindowSize=15,
            )

        return image

    def _alpha_to_color(self, image: np.ndarray, alpha_color: tuple[float, float, float]) -> np.ndarray:
        """
        Blend a 4-channel (B, G, R, A) image onto a solid ``alpha_color`` background.

        ``alpha_color`` is indexed as (R, G, B). Images that do not have exactly
        4 channels are returned unchanged.
        """
        if len(image.shape) == 3 and image.shape[2] == 4:
            B, G, R, A = cv2.split(image)
            alpha = A / 255

            R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
            G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
            B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)

            image = cv2.merge((B, G, R))

        return image

    def _binarize_img(self, image: np.ndarray):
        """
        Binarize a 3-channel image with Otsu thresholding and return it as 3-channel BGR.

        Other images are returned unchanged. Not called by this class.
        """
        if len(image.shape) == 3 and image.shape[2] == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # conversion to grayscale image
            # use cv2 threshold binarization
            _, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            image = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
        return image
|
|
@@ -1,126 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
|
|
3
|
-
from typing import Iterable
|
|
4
|
-
from shapely.geometry import Polygon
|
|
5
|
-
from PIL.Image import new, Image, Resampling
|
|
6
|
-
from .types import Layout, OCRFragment
|
|
7
|
-
from .ocr import OCR
|
|
8
|
-
from .overlap import overlap_rate
|
|
9
|
-
from .rectangle import Point, Rectangle
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
# Minimum overlap_rate for an OCR fragment to count as a match in _match_fragments.
_MIN_RATE: float = 0.5
|
|
13
|
-
|
|
14
|
-
def correct_fragments(ocr: OCR, source: Image, layout: Layout):
    """
    Re-run OCR on the region covered by ``layout`` and merge the result into
    ``layout.fragments`` in place.

    For each pair that ``_match_fragments`` matches, the fragment with the
    higher ``rank`` is kept; on a tie the freshly OCR'd fragment wins. All
    fragments that ``_match_fragments`` reports as unmatched are appended after
    the matched ones.
    """
    left, top, right, bottom = layout.rect.wrapper
    region: Image = source.crop((
        round(left), round(top),
        round(right), round(bottom),
    ))
    region, dx, dy, scale = _adjust_image(region)

    fresh_fragments = list(ocr.search_fragments(np.array(region)))
    for fragment in fresh_fragments:
        # OCR boxes are in the padded/rescaled canvas space; map them back to
        # page coordinates in place.
        # NOTE(review): the crop origin is (round(left), round(top)) from
        # rect.wrapper, while _apply_point translates by layout.rect.lt. These
        # agree only when lt is the wrapper's top-left corner; confirm for
        # rotated layouts.
        _apply_fragment(fragment.rect, layout, dx, dy, scale)

    matched, unmatched = _match_fragments(
        zone_rect=layout.rect,
        fragments1=layout.fragments,
        fragments2=fresh_fragments,
    )
    merged: list[OCRFragment] = [
        existing if existing.rank > fresh.rank else fresh
        for existing, fresh in matched
    ]
    merged.extend(unmatched)
    layout.fragments = merged
|
|
41
|
-
|
|
42
|
-
def _adjust_image(image: Image) -> tuple[Image, int, int, float]:
    """
    Place ``image`` on a white canvas with a 50 px border before OCR.

    Sizing rules:
    - If both sides are shorter than 924 px, the image is upscaled (aspect ratio
      preserved) until one side reaches 924 px.
    - Otherwise the image keeps its size.
    - The image is never downscaled.

    The scaled image is centered inside the core area. The core area is
    max(924, side) per axis; the border goes around it.

    Returns ``(canvas, dx, dy, scale)``. A canvas point (cx, cy) maps back to
    the input image as ((cx - dx) / scale, (cy - dy) / scale).

    An empty image (zero width or height) yields a blank canvas with
    ``scale == 1.0`` instead of raising ZeroDivisionError.
    """
    # after testing, adding white borders to images can reduce
    # the possibility of some text not being recognized
    border_size: int = 50
    adjusted_size: int = 1024 - 2 * border_size
    width, height = image.size
    core_width = float(max(adjusted_size, width))
    core_height = float(max(adjusted_size, height))
    is_empty = width <= 0 or height <= 0

    if is_empty:
        # Nothing to scale; avoid dividing by a zero dimension.
        scale = 1.0
    else:
        scale = min(core_width / width, core_height / height)

    adjusted_width = width * scale
    adjusted_height = height * scale
    dx = round((core_width - adjusted_width) / 2.0) + border_size
    dy = round((core_height - adjusted_height) / 2.0) + border_size

    if scale != 1.0:
        image = image.resize(
            (round(width * scale), round(height * scale)),
            Resampling.BICUBIC,
        )

    canvas_width = round(core_width) + 2 * border_size
    canvas_height = round(core_height) + 2 * border_size
    new_image = new("RGB", (canvas_width, canvas_height), (255, 255, 255))
    if not is_empty:
        new_image.paste(image, (dx, dy))

    return new_image, dx, dy, scale
|
|
73
|
-
|
|
74
|
-
def _apply_fragment(rect: Rectangle, layout: Layout, dx: int, dy: int, scale: float):
    """Map all four corners of ``rect`` from OCR-canvas space back to page space, in place."""
    # Same corner order as before: lt, lb, rb, rt.
    for corner in ("lt", "lb", "rb", "rt"):
        mapped = _apply_point(getattr(rect, corner), layout, dx, dy, scale)
        setattr(rect, corner, mapped)
|
|
79
|
-
|
|
80
|
-
def _apply_point(point: Point, layout: Layout, dx: int, dy: int, scale: float) -> Point:
    """
    Convert a point from the padded, rescaled OCR canvas back to page coordinates.

    The offset (``dx``, ``dy``) is removed and the result is divided by
    ``scale``. It is then translated by the layout rectangle's ``lt`` corner.
    """
    canvas_x, canvas_y = point
    page_x = (canvas_x - dx) / scale + layout.rect.lt[0]
    page_y = (canvas_y - dy) / scale + layout.rect.lt[1]
    return page_x, page_y
|
|
85
|
-
|
|
86
|
-
def _match_fragments(
    zone_rect: Rectangle,
    fragments1: Iterable[OCRFragment],
    fragments2: Iterable[OCRFragment],
) -> tuple[list[tuple[OCRFragment, OCRFragment]], list[OCRFragment]]:
    """
    Greedily pair fragments from ``fragments1`` with fragments from ``fragments2``.

    Each ``fragments1`` entry is first clipped to ``zone_rect``. It is then
    paired with the still-unclaimed ``fragments2`` entry of highest
    ``overlap_rate``, provided that rate is at least ``_MIN_RATE``. A
    ``fragments2`` entry can be claimed at most once.

    Returns ``(matched_pairs, unmatched)``:
    - ``matched_pairs`` holds ``(fragment1, fragment2)`` tuples.
    - ``unmatched`` holds the ``fragments1`` entries without a partner,
      followed by every unclaimed ``fragments2`` entry.
    """
    zone_polygon = Polygon(zone_rect)
    pool: list[OCRFragment] = list(fragments2)
    matched: list[tuple[OCRFragment, OCRFragment]] = []
    unmatched: list[OCRFragment] = []

    for fragment1 in fragments1:
        clipped = zone_polygon.intersection(Polygon(fragment1.rect))
        if clipped.is_empty:
            # NOTE(review): fragments entirely outside the zone are dropped
            # (neither matched nor unmatched); confirm this is intended.
            continue

        best_index = -1
        best_rate = 0.0
        for index, candidate in enumerate(pool):
            rate = overlap_rate(clipped, Polygon(candidate.rect))
            if rate >= _MIN_RATE and rate > best_rate:
                best_index, best_rate = index, rate

        if best_index == -1:
            unmatched.append(fragment1)
        else:
            matched.append((fragment1, pool.pop(best_index)))

    unmatched.extend(pool)
    return matched, unmatched
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from .predict_system import TextSystem
|
|
@@ -1,26 +0,0 @@
|
|
|
1
|
-
class ClsPostProcess:
    """Decode classifier scores into ``(label, score)`` pairs, one per row."""

    def __init__(self, label_list=None, key=None, **kwargs):
        super().__init__()
        # Maps a class index to its label; None means "use the index itself".
        self.label_list = label_list
        # When set, predictions are looked up as preds[key] before decoding.
        self.key = key

    def __call__(self, preds, label=None, *args, **kwargs):
        if self.key is not None:
            preds = preds[self.key]

        labels = self.label_list
        if labels is None:
            labels = {idx: idx for idx in range(preds.shape[-1])}

        best_classes = preds.argmax(axis=1)
        decoded = [
            (labels[cls_idx], preds[row, cls_idx])
            for row, cls_idx in enumerate(best_classes)
        ]
        if label is None:
            return decoded

        # Ground-truth indices are reported with a fixed score of 1.0.
        return decoded, [(labels[idx], 1.0) for idx in label]
|