doc-page-extractor 0.1.1__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc_page_extractor/__init__.py +5 -14
- doc_page_extractor/check_env.py +40 -0
- doc_page_extractor/extractor.py +87 -212
- doc_page_extractor/model.py +97 -0
- doc_page_extractor/parser.py +51 -0
- doc_page_extractor/plot.py +52 -79
- doc_page_extractor/redacter.py +111 -0
- doc_page_extractor-1.0.2.dist-info/METADATA +120 -0
- doc_page_extractor-1.0.2.dist-info/RECORD +11 -0
- {doc_page_extractor-0.1.1.dist-info → doc_page_extractor-1.0.2.dist-info}/WHEEL +1 -2
- doc_page_extractor-1.0.2.dist-info/licenses/LICENSE +21 -0
- doc_page_extractor/clipper.py +0 -119
- doc_page_extractor/downloader.py +0 -16
- doc_page_extractor/latex.py +0 -57
- doc_page_extractor/layout_order.py +0 -240
- doc_page_extractor/layoutreader.py +0 -126
- doc_page_extractor/ocr.py +0 -175
- doc_page_extractor/ocr_corrector.py +0 -126
- doc_page_extractor/onnxocr/__init__.py +0 -1
- doc_page_extractor/onnxocr/cls_postprocess.py +0 -26
- doc_page_extractor/onnxocr/db_postprocess.py +0 -246
- doc_page_extractor/onnxocr/imaug.py +0 -32
- doc_page_extractor/onnxocr/operators.py +0 -187
- doc_page_extractor/onnxocr/predict_base.py +0 -52
- doc_page_extractor/onnxocr/predict_cls.py +0 -89
- doc_page_extractor/onnxocr/predict_det.py +0 -120
- doc_page_extractor/onnxocr/predict_rec.py +0 -321
- doc_page_extractor/onnxocr/predict_system.py +0 -97
- doc_page_extractor/onnxocr/rec_postprocess.py +0 -896
- doc_page_extractor/onnxocr/utils.py +0 -71
- doc_page_extractor/overlap.py +0 -167
- doc_page_extractor/raw_optimizer.py +0 -104
- doc_page_extractor/rectangle.py +0 -72
- doc_page_extractor/rotation.py +0 -158
- doc_page_extractor/struct_eqtable/__init__.py +0 -49
- doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -2
- doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -394
- doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -198
- doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -81
- doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -3
- doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -76
- doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -1047
- doc_page_extractor/table.py +0 -71
- doc_page_extractor/types.py +0 -67
- doc_page_extractor/utils.py +0 -32
- doc_page_extractor-0.1.1.dist-info/METADATA +0 -84
- doc_page_extractor-0.1.1.dist-info/RECORD +0 -44
- doc_page_extractor-0.1.1.dist-info/licenses/LICENSE +0 -661
- doc_page_extractor-0.1.1.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -0
- tests/test_history_bus.py +0 -55
doc_page_extractor/layout_order.py
DELETED
@@ -1,240 +0,0 @@
-import os
-import torch
-
-from typing import Generator
-from dataclasses import dataclass
-from transformers import LayoutLMv3ForTokenClassification
-
-from .types import Layout, LayoutClass
-from .layoutreader import prepare_inputs, boxes2inputs, parse_logits
-from .utils import ensure_dir
-
-
-@dataclass
-class _BBox:
-  layout_index: int
-  fragment_index: int
-  virtual: bool
-  order: int
-  value: tuple[float, float, float, float]
-
-class LayoutOrder:
-  def __init__(self, model_path: str):
-    self._model_path: str = model_path
-    self._model: LayoutLMv3ForTokenClassification | None = None
-
-  def _get_model(self) -> LayoutLMv3ForTokenClassification:
-    if self._model is None:
-      model_path = ensure_dir(self._model_path)
-      self._model = LayoutLMv3ForTokenClassification.from_pretrained(
-        pretrained_model_name_or_path="hantian/layoutreader",
-        cache_dir=model_path,
-        local_files_only=os.path.exists(os.path.join(model_path, "models--hantian--layoutreader")),
-      )
-    return self._model
-
-  def sort(self, layouts: list[Layout], size: tuple[int, int]) -> list[Layout]:
-    width, height = size
-    if width == 0 or height == 0:
-      return layouts
-
-    bbox_list = self._order_and_get_bbox_list(
-      layouts=layouts,
-      width=width,
-      height=height,
-    )
-    if bbox_list is None:
-      return layouts
-
-    return self._sort_layouts_and_fragments(layouts, bbox_list)
-
-  def _order_and_get_bbox_list(
-    self,
-    layouts: list[Layout],
-    width: int,
-    height: int,
-  ) -> list[_BBox] | None:
-
-    line_height = self._line_height(layouts)
-    bbox_list: list[_BBox] = []
-
-    for i, layout in enumerate(layouts):
-      if layout.cls == LayoutClass.PLAIN_TEXT and \
-         len(layout.fragments) > 0:
-        for j, fragment in enumerate(layout.fragments):
-          bbox_list.append(_BBox(
-            layout_index=i,
-            fragment_index=j,
-            virtual=False,
-            order=0,
-            value=fragment.rect.wrapper,
-          ))
-      else:
-        bbox_list.extend(
-          self._generate_virtual_lines(
-            layout=layout,
-            layout_index=i,
-            line_height=line_height,
-            width=width,
-            height=height,
-          ),
-        )
-
-    if len(bbox_list) > 200:
-      # https://github.com/opendatalab/MinerU/blob/980f5c8cd70f22f8c0c9b7b40eaff6f4804e6524/magic_pdf/pdf_parse_union_core_v2.py#L522
-      return None
-
-    layoutreader_size = 1000.0
-    x_scale = layoutreader_size / float(width)
-    y_scale = layoutreader_size / float(height)
-
-    for bbox in bbox_list:
-      x0, y0, x1, y1 = self._squeeze(bbox.value, width, height)
-      x0 = round(x0 * x_scale)
-      y0 = round(y0 * y_scale)
-      x1 = round(x1 * x_scale)
-      y1 = round(y1 * y_scale)
-      bbox.value = (x0, y0, x1, y1)
-
-    bbox_list.sort(key=lambda b: b.value)  # must sort: passing boxes to layoutreader out of order keeps it from recovering the correct order
-    model = self._get_model()
-
-    with torch.no_grad():
-      inputs = boxes2inputs([list(bbox.value) for bbox in bbox_list])
-      inputs = prepare_inputs(inputs, model)
-      logits = model(**inputs).logits.cpu().squeeze(0)
-      orders = parse_logits(logits, len(bbox_list))
-
-    sorted_bbox_list = [bbox_list[i] for i in orders]
-    for i, bbox in enumerate(sorted_bbox_list):
-      bbox.order = i
-
-    return sorted_bbox_list
-
-  def _sort_layouts_and_fragments(self, layouts: list[Layout], bbox_list: list[_BBox]):
-    layout_bbox_list: list[list[_BBox]] = [[] for _ in range(len(layouts))]
-    for bbox in bbox_list:
-      layout_bbox_list[bbox.layout_index].append(bbox)
-
-    layouts_with_median_order: list[tuple[Layout, float]] = []
-    for layout_index, bbox_list in enumerate(layout_bbox_list):
-      layout = layouts[layout_index]
-      orders = [b.order for b in bbox_list]  # virtual bboxes guarantee that orders can never be empty
-      median_order = self._median(orders)
-      layouts_with_median_order.append((layout, median_order))
-
-    for layout, bbox_list in zip(layouts, layout_bbox_list):
-      for bbox in bbox_list:
-        if not bbox.virtual:
-          layout.fragments[bbox.fragment_index].order = bbox.order
-      if all(not bbox.virtual for bbox in bbox_list):
-        layout.fragments.sort(key=lambda f: f.order)
-
-    layouts_with_median_order.sort(key=lambda x: x[1])
-    layouts = [layout for layout, _ in layouts_with_median_order]
-    next_fragment_order: int = 0
-
-    for layout in layouts:
-      for fragment in layout.fragments:
-        fragment.order = next_fragment_order
-        next_fragment_order += 1
-
-    return layouts
-
-  def _line_height(self, layouts: list[Layout]) -> float:
-    line_height: float = 0.0
-    count: int = 0
-    for layout in layouts:
-      for fragment in layout.fragments:
-        _, height = fragment.rect.size
-        line_height += height
-        count += 1
-    if count == 0:
-      return 10.0
-    return line_height / float(count)
-
-  def _generate_virtual_lines(
-    self,
-    layout: Layout,
-    layout_index: int,
-    line_height: float,
-    width: int,
-    height: int,
-  ) -> Generator[_BBox, None, None]:
-
-    # https://github.com/opendatalab/MinerU/blob/980f5c8cd70f22f8c0c9b7b40eaff6f4804e6524/magic_pdf/pdf_parse_union_core_v2.py#L451-L490
-    x0, y0, x1, y1 = layout.rect.wrapper
-    layout_height = y1 - y0
-    layout_weight = x1 - x0
-    lines = int(layout_height / line_height)
-
-    if layout_height <= line_height * 2:
-      yield _BBox(
-        layout_index=layout_index,
-        fragment_index=0,
-        virtual=True,
-        order=0,
-        value=(x0, y0, x1, y1),
-      )
-      return
-
-    elif layout_height <= height * 0.25 or \
-         width * 0.5 <= layout_weight or \
-         width * 0.25 < layout_weight:
-      if layout_weight > width * 0.4:
-        lines = 3
-      elif layout_weight <= width * 0.25:
-        if layout_height / layout_weight > 1.2:  # tall and narrow: do not split
-          yield _BBox(
-            layout_index=layout_index,
-            fragment_index=0,
-            virtual=True,
-            order=0,
-            value=(x0, y0, x1, y1),
-          )
-          return
-        else:  # not tall and narrow: still split into two lines
-          lines = 2
-
-    lines = max(1, lines)
-    line_height = (y1 - y0) / lines
-    current_y = y0
-
-    for i in range(lines):
-      yield _BBox(
-        layout_index=layout_index,
-        fragment_index=i,
-        virtual=True,
-        order=0,
-        value=(x0, current_y, x1, current_y + line_height),
-      )
-      current_y += line_height
-
-  def _median(self, numbers: list[int]) -> float:
-    sorted_numbers = sorted(numbers)
-    n = len(sorted_numbers)
-
-    # check whether the count of elements is odd or even
-    if n % 2 == 1:
-      # odd: take the middle number directly
-      return float(sorted_numbers[n // 2])
-    else:
-      # even: take the mean of the two middle numbers
-      mid1 = sorted_numbers[n // 2 - 1]
-      mid2 = sorted_numbers[n // 2]
-      return float((mid1 + mid2) / 2)
-
-  def _squeeze(self, bbox: _BBox, width: int, height: int) -> _BBox:
-    x0, y0, x1, y1 = bbox
-    x0 = self._squeeze_value(x0, width)
-    x1 = self._squeeze_value(x1, width)
-    y0 = self._squeeze_value(y0, height)
-    y1 = self._squeeze_value(y1, height)
-    return x0, y0, x1, y1
-
-  def _squeeze_value(self, position: float, size: int) -> float:
-    if position < 0:
-      position = 0.0
-    if position > size:
-      position = float(size)
-    return position
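Note: the heart of the deleted _generate_virtual_lines above is splitting a fragment-less layout box into evenly spaced horizontal strips, so that layoutreader still receives boxes at roughly line granularity. A minimal standalone sketch of that splitting step (the function name and simplified tuple signature are illustrative, not part of the package):

    def split_into_virtual_lines(box, lines):
        # Split one layout box into `lines` equal horizontal strips,
        # mirroring the even-split loop at the end of _generate_virtual_lines.
        x0, y0, x1, y1 = box
        line_height = (y1 - y0) / lines
        return [
            (x0, y0 + i * line_height, x1, y0 + (i + 1) * line_height)
            for i in range(lines)
        ]

    print(split_into_virtual_lines((0.0, 0.0, 100.0, 30.0), 3))
    # [(0.0, 0.0, 100.0, 10.0), (0.0, 10.0, 100.0, 20.0), (0.0, 20.0, 100.0, 30.0)]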
doc_page_extractor/layoutreader.py
DELETED
@@ -1,126 +0,0 @@
-# Copy from https://github.com/ppaanngggg/layoutreader/blob/main/v3/helpers.py
-from collections import defaultdict
-from typing import List, Dict
-
-import torch
-from transformers import LayoutLMv3ForTokenClassification
-
-MAX_LEN = 510
-CLS_TOKEN_ID = 0
-UNK_TOKEN_ID = 3
-EOS_TOKEN_ID = 2
-
-
-class DataCollator:
-    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
-        bbox = []
-        labels = []
-        input_ids = []
-        attention_mask = []
-
-        # clip bbox and labels to max length, build input_ids and attention_mask
-        for feature in features:
-            _bbox = feature["source_boxes"]
-            if len(_bbox) > MAX_LEN:
-                _bbox = _bbox[:MAX_LEN]
-            _labels = feature["target_index"]
-            if len(_labels) > MAX_LEN:
-                _labels = _labels[:MAX_LEN]
-            _input_ids = [UNK_TOKEN_ID] * len(_bbox)
-            _attention_mask = [1] * len(_bbox)
-            assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
-            bbox.append(_bbox)
-            labels.append(_labels)
-            input_ids.append(_input_ids)
-            attention_mask.append(_attention_mask)
-
-        # add CLS and EOS tokens
-        for i in range(len(bbox)):
-            bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
-            labels[i] = [-100] + labels[i] + [-100]
-            input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
-            attention_mask[i] = [1] + attention_mask[i] + [1]
-
-        # padding to max length
-        max_len = max(len(x) for x in bbox)
-        for i in range(len(bbox)):
-            bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
-            labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
-            input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
-            attention_mask[i] = attention_mask[i] + [0] * (
-                max_len - len(attention_mask[i])
-            )
-
-        ret = {
-            "bbox": torch.tensor(bbox),
-            "attention_mask": torch.tensor(attention_mask),
-            "labels": torch.tensor(labels),
-            "input_ids": torch.tensor(input_ids),
-        }
-        # set label > MAX_LEN to -100, because original labels may be > MAX_LEN
-        ret["labels"][ret["labels"] > MAX_LEN] = -100
-        # set label > 0 to label-1, because original labels are 1-indexed
-        ret["labels"][ret["labels"] > 0] -= 1
-        return ret
-
-
-def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
-    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
-    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
-    attention_mask = [1] + [1] * len(boxes) + [1]
-    return {
-        "bbox": torch.tensor([bbox]),
-        "attention_mask": torch.tensor([attention_mask]),
-        "input_ids": torch.tensor([input_ids]),
-    }
-
-
-def prepare_inputs(
-    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
-) -> Dict[str, torch.Tensor]:
-    ret = {}
-    for k, v in inputs.items():
-        v = v.to(model.device)
-        if torch.is_floating_point(v):
-            v = v.to(model.dtype)
-        ret[k] = v
-    return ret
-
-
-def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
-    """
-    parse logits to orders
-
-    :param logits: logits from model
-    :param length: input length
-    :return: orders
-    """
-    logits = logits[1 : length + 1, :length]
-    orders = logits.argsort(descending=False).tolist()
-    ret = [o.pop() for o in orders]
-    while True:
-        order_to_idxes = defaultdict(list)
-        for idx, order in enumerate(ret):
-            order_to_idxes[order].append(idx)
-        # filter idxes len > 1
-        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
-        if not order_to_idxes:
-            break
-        # filter
-        for order, idxes in order_to_idxes.items():
-            # find original logits of idxes
-            idxes_to_logit = {}
-            for idx in idxes:
-                idxes_to_logit[idx] = logits[idx, order]
-            idxes_to_logit = sorted(
-                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
-            )
-            # keep the highest logit as order, set others to next candidate
-            for idx, _ in idxes_to_logit[1:]:
-                ret[idx] = orders[idx].pop()
-
-    return ret
-
-
-def check_duplicate(a: List[int]) -> bool:
-    return len(a) != len(set(a))
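Note: parse_logits above resolves duplicate position assignments greedily by logit value. A toy run with hand-built logits in place of a real model forward pass (the 5x3 shape assumes one CLS row, one row per box, and one EOS row, matching what boxes2inputs produces for three boxes; parse_logits itself is the function shown in the hunk above):

    import torch

    # Three boxes whose intended reading order is: box 1, box 2, box 0.
    logits = torch.full((5, 3), -10.0)  # CLS + 3 boxes + EOS rows
    logits[1, 2] = 5.0  # box 0 -> position 2
    logits[2, 0] = 5.0  # box 1 -> position 0
    logits[3, 1] = 5.0  # box 2 -> position 1

    print(parse_logits(logits, 3))  # [2, 0, 1]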
doc_page_extractor/ocr.py
DELETED
@@ -1,175 +0,0 @@
-import numpy as np
-import cv2
-import os
-
-from typing import Literal, Generator
-from dataclasses import dataclass
-from .onnxocr import TextSystem
-from .types import OCRFragment
-from .rectangle import Rectangle
-from .downloader import download
-from .utils import is_space_text
-
-
-_MODELS = (
-  ("ppocrv4", "rec", "rec.onnx"),
-  ("ppocrv4", "cls", "cls.onnx"),
-  ("ppocrv4", "det", "det.onnx"),
-  ("ch_ppocr_server_v2.0", "ppocr_keys_v1.txt"),
-)
-
-@dataclass
-class _OONXParams:
-  use_angle_cls: bool
-  use_gpu: bool
-  rec_image_shape: tuple[int, int, int]
-  cls_image_shape: tuple[int, int, int]
-  cls_batch_num: int
-  cls_thresh: float
-  label_list: list[str]
-
-  det_algorithm: str
-  det_limit_side_len: int
-  det_limit_type: str
-  det_db_thresh: float
-  det_db_box_thresh: float
-  det_db_unclip_ratio: float
-  use_dilation: bool
-  det_db_score_mode: str
-  det_box_type: str
-  rec_batch_num: int
-  drop_score: float
-  save_crop_res: bool
-  rec_algorithm: str
-  use_space_char: bool
-  rec_model_dir: str
-  cls_model_dir: str
-  det_model_dir: str
-  rec_char_dict_path: str
-
-class OCR:
-  def __init__(
-    self,
-    device: Literal["cpu", "cuda"],
-    model_dir_path: str,
-  ):
-    self._device: Literal["cpu", "cuda"] = device
-    self._model_dir_path: str = model_dir_path
-    self._text_system: TextSystem | None = None
-
-  def search_fragments(self, image: np.ndarray) -> Generator[OCRFragment, None, None]:
-    for box, res in self._ocr(image):
-      text, rank = res
-      if is_space_text(text):
-        continue
-
-      rect = Rectangle(
-        lt=(box[0][0], box[0][1]),
-        rt=(box[1][0], box[1][1]),
-        rb=(box[2][0], box[2][1]),
-        lb=(box[3][0], box[3][1]),
-      )
-      if not rect.is_valid or rect.area == 0.0:
-        continue
-
-      yield OCRFragment(
-        order=0,
-        text=text,
-        rank=rank,
-        rect=rect,
-      )
-
-  def _ocr(self, image: np.ndarray) -> Generator[tuple[list[list[float]], tuple[str, float]], None, None]:
-    text_system = self._get_text_system()
-    image = self._preprocess_image(image)
-    dt_boxes, rec_res = text_system(image)
-
-    for box, res in zip(dt_boxes, rec_res):
-      yield box.tolist(), res
-
-  def _get_text_system(self) -> TextSystem:
-    if self._text_system is None:
-      for model_path in _MODELS:
-        file_path = os.path.join(self._model_dir_path, *model_path)
-        if os.path.exists(file_path):
-          continue
-
-        file_dir_path = os.path.dirname(file_path)
-        os.makedirs(file_dir_path, exist_ok=True)
-
-        url_path = "/".join(model_path)
-        url = f"https://huggingface.co/moskize/OnnxOCR/resolve/main/{url_path}"
-        download(url, file_path)
-
-      self._text_system = TextSystem(_OONXParams(
-        use_angle_cls=True,
-        use_gpu=(self._device != "cpu"),
-        rec_image_shape=(3, 48, 320),
-        cls_image_shape=(3, 48, 192),
-        cls_batch_num=6,
-        cls_thresh=0.9,
-        label_list=["0", "180"],
-        det_algorithm="DB",
-        det_limit_side_len=960,
-        det_limit_type="max",
-        det_db_thresh=0.3,
-        det_db_box_thresh=0.6,
-        det_db_unclip_ratio=1.5,
-        use_dilation=False,
-        det_db_score_mode="fast",
-        det_box_type="quad",
-        rec_batch_num=6,
-        drop_score=0.5,
-        save_crop_res=False,
-        rec_algorithm="SVTR_LCNet",
-        use_space_char=True,
-        rec_model_dir=os.path.join(self._model_dir_path, *_MODELS[0]),
-        cls_model_dir=os.path.join(self._model_dir_path, *_MODELS[1]),
-        det_model_dir=os.path.join(self._model_dir_path, *_MODELS[2]),
-        rec_char_dict_path=os.path.join(self._model_dir_path, *_MODELS[3]),
-      ))
-
-    return self._text_system
-
-  def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
-    image = self._alpha_to_color(image, (255, 255, 255))
-    # image = cv2.bitwise_not(image) # inv
-    # image = self._binarize_img(image) # bin
-    image = cv2.normalize(
-      src=image,
-      dst=np.zeros((image.shape[0], image.shape[1])),
-      alpha=0,
-      beta=255,
-      norm_type=cv2.NORM_MINMAX,
-    )
-    image = cv2.fastNlMeansDenoisingColored(
-      src=image,
-      dst=None,
-      h=10,
-      hColor=10,
-      templateWindowSize=7,
-      searchWindowSize=15,
-    )
-    # image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # image to gray
-    return image
-
-  def _alpha_to_color(self, image: np.ndarray, alpha_color: tuple[float, float, float]) -> np.ndarray:
-    if len(image.shape) == 3 and image.shape[2] == 4:
-      B, G, R, A = cv2.split(image)
-      alpha = A / 255
-
-      R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
-      G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
-      B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
-
-      image = cv2.merge((B, G, R))
-
-    return image
-
-  def _binarize_img(self, image: np.ndarray):
-    if len(image.shape) == 3 and image.shape[2] == 3:
-      gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # conversion to grayscale image
-      # use cv2 threshold binarization
-      _, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-      image = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
-    return image
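Note: the _alpha_to_color step above is plain alpha compositing onto a solid background. A pure-NumPy sketch of the same blend, without the cv2 channel split/merge (hypothetical helper name, white background assumed):

    import numpy as np

    def alpha_to_white(image: np.ndarray) -> np.ndarray:
        # Composite a 4-channel image onto white using its alpha channel,
        # matching the per-channel blend in _alpha_to_color above.
        if image.ndim == 3 and image.shape[2] == 4:
            alpha = image[:, :, 3:4] / 255.0
            color = image[:, :, :3].astype(np.float64)
            image = (255.0 * (1.0 - alpha) + color * alpha).astype(np.uint8)
        return image

    rgba = np.zeros((2, 2, 4), dtype=np.uint8)  # fully transparent black pixels
    print(alpha_to_white(rgba)[0, 0])           # [255 255 255]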
doc_page_extractor/ocr_corrector.py
DELETED
@@ -1,126 +0,0 @@
-import numpy as np
-
-from typing import Iterable
-from shapely.geometry import Polygon
-from PIL.Image import new, Image, Resampling
-from .types import Layout, OCRFragment
-from .ocr import OCR
-from .overlap import overlap_rate
-from .rectangle import Point, Rectangle
-
-
-_MIN_RATE = 0.5
-
-def correct_fragments(ocr: OCR, source: Image, layout: Layout):
-  x1, y1, x2, y2 = layout.rect.wrapper
-  image: Image = source.crop((
-    round(x1), round(y1),
-    round(x2), round(y2),
-  ))
-  image, dx, dy, scale = _adjust_image(image)
-  image_np = np.array(image)
-  ocr_fragments = list(ocr.search_fragments(image_np))
-  corrected_fragments: list[OCRFragment] = []
-
-  for fragment in ocr_fragments:
-    _apply_fragment(fragment.rect, layout, dx, dy, scale)
-
-  matched_fragments, not_matched_fragments = _match_fragments(
-    zone_rect=layout.rect,
-    fragments1=layout.fragments,
-    fragments2=ocr_fragments,
-  )
-  for fragment1, fragment2 in matched_fragments:
-    if fragment1.rank > fragment2.rank:
-      corrected_fragments.append(fragment1)
-    else:
-      corrected_fragments.append(fragment2)
-
-  corrected_fragments.extend(not_matched_fragments)
-  layout.fragments = corrected_fragments
-
-def _adjust_image(image: Image) -> tuple[Image, int, int, float]:
-  # after testing, adding white borders to images can reduce
-  # the possibility of some text not being recognized
-  border_size: int = 50
-  adjusted_size: int = 1024 - 2 * border_size
-  width, height = image.size
-  core_width = float(max(adjusted_size, width))
-  core_height = float(max(adjusted_size, height))
-
-  scale_x = core_width / width
-  scale_y = core_height / height
-  scale = min(scale_x, scale_y)
-  adjusted_width = width * scale
-  adjusted_height = height * scale
-
-  dx = (core_width - adjusted_width) / 2.0
-  dy = (core_height - adjusted_height) / 2.0
-  dx = round(dx) + border_size
-  dy = round(dy) + border_size
-
-  if scale != 1.0:
-    width = round(width * scale)
-    height = round(height * scale)
-    image = image.resize((width, height), Resampling.BICUBIC)
-
-  width = round(core_width) + 2 * border_size
-  height = round(core_height) + 2 * border_size
-  new_image = new("RGB", (width, height), (255, 255, 255))
-  new_image.paste(image, (dx, dy))
-
-  return new_image, dx, dy, scale
-
-def _apply_fragment(rect: Rectangle, layout: Layout, dx: int, dy: int, scale: float):
-  rect.lt = _apply_point(rect.lt, layout, dx, dy, scale)
-  rect.lb = _apply_point(rect.lb, layout, dx, dy, scale)
-  rect.rb = _apply_point(rect.rb, layout, dx, dy, scale)
-  rect.rt = _apply_point(rect.rt, layout, dx, dy, scale)
-
-def _apply_point(point: Point, layout: Layout, dx: int, dy: int, scale: float) -> Point:
-  x, y = point
-  x = (x - dx) / scale + layout.rect.lt[0]
-  y = (y - dy) / scale + layout.rect.lt[1]
-  return x, y
-
-def _match_fragments(
-  zone_rect: Rectangle,
-  fragments1: Iterable[OCRFragment],
-  fragments2: Iterable[OCRFragment],
-) -> tuple[list[tuple[OCRFragment, OCRFragment]], list[OCRFragment]]:
-
-  zone_polygon = Polygon(zone_rect)
-  fragments2: list[OCRFragment] = list(fragments2)
-  matched_fragments: list[tuple[OCRFragment, OCRFragment]] = []
-  not_matched_fragments: list[OCRFragment] = []
-
-  for fragment1 in fragments1:
-    polygon1 = Polygon(fragment1.rect)
-    polygon1 = zone_polygon.intersection(polygon1)
-    if polygon1.is_empty:
-      continue
-
-    beast_j = -1
-    beast_rate = 0.0
-
-    for j, fragment2 in enumerate(fragments2):
-      polygon2 = Polygon(fragment2.rect)
-      rate = overlap_rate(polygon1, polygon2)
-      if rate < _MIN_RATE:
-        continue
-
-      if rate > beast_rate:
-        beast_j = j
-        beast_rate = rate
-
-    if beast_j != -1:
-      matched_fragments.append((
-        fragment1,
-        fragments2[beast_j],
-      ))
-      del fragments2[beast_j]
-    else:
-      not_matched_fragments.append(fragment1)
-
-  not_matched_fragments.extend(fragments2)
-  return matched_fragments, not_matched_fragments
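Note: _apply_point above undoes, in order, the white-border offset, the resize scale, and the crop origin introduced by _adjust_image. A standalone sketch of that coordinate round trip (function and parameter names are illustrative):

    def crop_point_to_page(point, crop_origin, dx, dy, scale):
        # Map a point detected on the padded, scaled crop back to page space:
        # subtract the padding offset, undo the scale, add the crop origin.
        x, y = point
        return ((x - dx) / scale + crop_origin[0],
                (y - dy) / scale + crop_origin[1])

    # A crop taken at page position (100, 200), padded by dx = dy = 50
    # and scaled by 2.0 before OCR:
    print(crop_point_to_page((150.0, 250.0), (100.0, 200.0), 50, 50, 2.0))
    # (150.0, 300.0)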
doc_page_extractor/onnxocr/__init__.py
DELETED
@@ -1 +0,0 @@
-from .predict_system import TextSystem
doc_page_extractor/onnxocr/cls_postprocess.py
DELETED
@@ -1,26 +0,0 @@
-class ClsPostProcess(object):
-    """ Convert between text-label and text-index """
-
-    def __init__(self, label_list=None, key=None, **kwargs):
-        super(ClsPostProcess, self).__init__()
-        self.label_list = label_list
-        self.key = key
-
-    def __call__(self, preds, label=None, *args, **kwargs):
-        if self.key is not None:
-            preds = preds[self.key]
-
-        label_list = self.label_list
-        if label_list is None:
-            label_list = {idx: idx for idx in range(preds.shape[-1])}
-
-        # if isinstance(preds, paddle.Tensor):
-        #     preds = preds.numpy()
-
-        pred_idxs = preds.argmax(axis=1)
-        decode_out = [(label_list[idx], preds[i, idx])
-                      for i, idx in enumerate(pred_idxs)]
-        if label is None:
-            return decode_out
-        label = [(label_list[idx], 1.0) for idx in label]
-        return decode_out, label
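Note: a toy run of ClsPostProcess as defined above, using the same label_list that ocr.py passed for orientation classification (the score values are assumed for illustration):

    import numpy as np

    post = ClsPostProcess(label_list=["0", "180"])  # orientation classes
    preds = np.array([[0.9, 0.1], [0.2, 0.8]])      # classifier scores per image
    print(post(preds))  # label/confidence pairs: ('0', 0.9) and ('180', 0.8)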