doc-page-extractor 0.1.1__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. doc_page_extractor/__init__.py +5 -14
  2. doc_page_extractor/check_env.py +40 -0
  3. doc_page_extractor/extractor.py +87 -212
  4. doc_page_extractor/model.py +97 -0
  5. doc_page_extractor/parser.py +51 -0
  6. doc_page_extractor/plot.py +52 -79
  7. doc_page_extractor/redacter.py +111 -0
  8. doc_page_extractor-1.0.2.dist-info/METADATA +120 -0
  9. doc_page_extractor-1.0.2.dist-info/RECORD +11 -0
  10. {doc_page_extractor-0.1.1.dist-info → doc_page_extractor-1.0.2.dist-info}/WHEEL +1 -2
  11. doc_page_extractor-1.0.2.dist-info/licenses/LICENSE +21 -0
  12. doc_page_extractor/clipper.py +0 -119
  13. doc_page_extractor/downloader.py +0 -16
  14. doc_page_extractor/latex.py +0 -57
  15. doc_page_extractor/layout_order.py +0 -240
  16. doc_page_extractor/layoutreader.py +0 -126
  17. doc_page_extractor/ocr.py +0 -175
  18. doc_page_extractor/ocr_corrector.py +0 -126
  19. doc_page_extractor/onnxocr/__init__.py +0 -1
  20. doc_page_extractor/onnxocr/cls_postprocess.py +0 -26
  21. doc_page_extractor/onnxocr/db_postprocess.py +0 -246
  22. doc_page_extractor/onnxocr/imaug.py +0 -32
  23. doc_page_extractor/onnxocr/operators.py +0 -187
  24. doc_page_extractor/onnxocr/predict_base.py +0 -52
  25. doc_page_extractor/onnxocr/predict_cls.py +0 -89
  26. doc_page_extractor/onnxocr/predict_det.py +0 -120
  27. doc_page_extractor/onnxocr/predict_rec.py +0 -321
  28. doc_page_extractor/onnxocr/predict_system.py +0 -97
  29. doc_page_extractor/onnxocr/rec_postprocess.py +0 -896
  30. doc_page_extractor/onnxocr/utils.py +0 -71
  31. doc_page_extractor/overlap.py +0 -167
  32. doc_page_extractor/raw_optimizer.py +0 -104
  33. doc_page_extractor/rectangle.py +0 -72
  34. doc_page_extractor/rotation.py +0 -158
  35. doc_page_extractor/struct_eqtable/__init__.py +0 -49
  36. doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -2
  37. doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -394
  38. doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -198
  39. doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -81
  40. doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -3
  41. doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -76
  42. doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -1047
  43. doc_page_extractor/table.py +0 -71
  44. doc_page_extractor/types.py +0 -67
  45. doc_page_extractor/utils.py +0 -32
  46. doc_page_extractor-0.1.1.dist-info/METADATA +0 -84
  47. doc_page_extractor-0.1.1.dist-info/RECORD +0 -44
  48. doc_page_extractor-0.1.1.dist-info/licenses/LICENSE +0 -661
  49. doc_page_extractor-0.1.1.dist-info/top_level.txt +0 -2
  50. tests/__init__.py +0 -0
  51. tests/test_history_bus.py +0 -55
doc_page_extractor/layout_order.py DELETED
@@ -1,240 +0,0 @@
- import os
- import torch
-
- from typing import Generator
- from dataclasses import dataclass
- from transformers import LayoutLMv3ForTokenClassification
-
- from .types import Layout, LayoutClass
- from .layoutreader import prepare_inputs, boxes2inputs, parse_logits
- from .utils import ensure_dir
-
-
- @dataclass
- class _BBox:
-     layout_index: int
-     fragment_index: int
-     virtual: bool
-     order: int
-     value: tuple[float, float, float, float]
-
- class LayoutOrder:
-     def __init__(self, model_path: str):
-         self._model_path: str = model_path
-         self._model: LayoutLMv3ForTokenClassification | None = None
-
-     def _get_model(self) -> LayoutLMv3ForTokenClassification:
-         if self._model is None:
-             model_path = ensure_dir(self._model_path)
-             self._model = LayoutLMv3ForTokenClassification.from_pretrained(
-                 pretrained_model_name_or_path="hantian/layoutreader",
-                 cache_dir=model_path,
-                 local_files_only=os.path.exists(os.path.join(model_path, "models--hantian--layoutreader")),
-             )
-         return self._model
-
-     def sort(self, layouts: list[Layout], size: tuple[int, int]) -> list[Layout]:
-         width, height = size
-         if width == 0 or height == 0:
-             return layouts
-
-         bbox_list = self._order_and_get_bbox_list(
-             layouts=layouts,
-             width=width,
-             height=height,
-         )
-         if bbox_list is None:
-             return layouts
-
-         return self._sort_layouts_and_fragments(layouts, bbox_list)
-
-     def _order_and_get_bbox_list(
-         self,
-         layouts: list[Layout],
-         width: int,
-         height: int,
-     ) -> list[_BBox] | None:
-
-         line_height = self._line_height(layouts)
-         bbox_list: list[_BBox] = []
-
-         for i, layout in enumerate(layouts):
-             if layout.cls == LayoutClass.PLAIN_TEXT and \
-                len(layout.fragments) > 0:
-                 for j, fragment in enumerate(layout.fragments):
-                     bbox_list.append(_BBox(
-                         layout_index=i,
-                         fragment_index=j,
-                         virtual=False,
-                         order=0,
-                         value=fragment.rect.wrapper,
-                     ))
-             else:
-                 bbox_list.extend(
-                     self._generate_virtual_lines(
-                         layout=layout,
-                         layout_index=i,
-                         line_height=line_height,
-                         width=width,
-                         height=height,
-                     ),
-                 )
-
-         if len(bbox_list) > 200:
-             # https://github.com/opendatalab/MinerU/blob/980f5c8cd70f22f8c0c9b7b40eaff6f4804e6524/magic_pdf/pdf_parse_union_core_v2.py#L522
-             return None
-
-         layoutreader_size = 1000.0
-         x_scale = layoutreader_size / float(width)
-         y_scale = layoutreader_size / float(height)
-
-         for bbox in bbox_list:
-             x0, y0, x1, y1 = self._squeeze(bbox.value, width, height)
-             x0 = round(x0 * x_scale)
-             y0 = round(y0 * y_scale)
-             x1 = round(x1 * x_scale)
-             y1 = round(y1 * y_scale)
-             bbox.value = (x0, y0, x1, y1)
-
-         bbox_list.sort(key=lambda b: b.value)  # must be sorted: passing boxes to layoutreader out of order keeps it from recovering the correct reading order
-         model = self._get_model()
-
-         with torch.no_grad():
-             inputs = boxes2inputs([list(bbox.value) for bbox in bbox_list])
-             inputs = prepare_inputs(inputs, model)
-             logits = model(**inputs).logits.cpu().squeeze(0)
-             orders = parse_logits(logits, len(bbox_list))
-
-         sorted_bbox_list = [bbox_list[i] for i in orders]
-         for i, bbox in enumerate(sorted_bbox_list):
-             bbox.order = i
-
-         return sorted_bbox_list
-
-     def _sort_layouts_and_fragments(self, layouts: list[Layout], bbox_list: list[_BBox]):
-         layout_bbox_list: list[list[_BBox]] = [[] for _ in range(len(layouts))]
-         for bbox in bbox_list:
-             layout_bbox_list[bbox.layout_index].append(bbox)
-
-         layouts_with_median_order: list[tuple[Layout, float]] = []
-         for layout_index, bbox_list in enumerate(layout_bbox_list):
-             layout = layouts[layout_index]
-             orders = [b.order for b in bbox_list]  # virtual bboxes guarantee that orders can never be empty
-             median_order = self._median(orders)
-             layouts_with_median_order.append((layout, median_order))
-
-         for layout, bbox_list in zip(layouts, layout_bbox_list):
-             for bbox in bbox_list:
-                 if not bbox.virtual:
-                     layout.fragments[bbox.fragment_index].order = bbox.order
-             if all(not bbox.virtual for bbox in bbox_list):
-                 layout.fragments.sort(key=lambda f: f.order)
-
-         layouts_with_median_order.sort(key=lambda x: x[1])
-         layouts = [layout for layout, _ in layouts_with_median_order]
-         next_fragment_order: int = 0
-
-         for layout in layouts:
-             for fragment in layout.fragments:
-                 fragment.order = next_fragment_order
-                 next_fragment_order += 1
-
-         return layouts
-
-     def _line_height(self, layouts: list[Layout]) -> float:
-         line_height: float = 0.0
-         count: int = 0
-         for layout in layouts:
-             for fragment in layout.fragments:
-                 _, height = fragment.rect.size
-                 line_height += height
-                 count += 1
-         if count == 0:
-             return 10.0
-         return line_height / float(count)
-
-     def _generate_virtual_lines(
-         self,
-         layout: Layout,
-         layout_index: int,
-         line_height: float,
-         width: int,
-         height: int,
-     ) -> Generator[_BBox, None, None]:
-
-         # https://github.com/opendatalab/MinerU/blob/980f5c8cd70f22f8c0c9b7b40eaff6f4804e6524/magic_pdf/pdf_parse_union_core_v2.py#L451-L490
-         x0, y0, x1, y1 = layout.rect.wrapper
-         layout_height = y1 - y0
-         layout_width = x1 - x0
-         lines = int(layout_height / line_height)
-
-         if layout_height <= line_height * 2:
-             yield _BBox(
-                 layout_index=layout_index,
-                 fragment_index=0,
-                 virtual=True,
-                 order=0,
-                 value=(x0, y0, x1, y1),
-             )
-             return
-
-         elif layout_height <= height * 0.25 or \
-              width * 0.5 <= layout_width or \
-              width * 0.25 < layout_width:
-             if layout_width > width * 0.4:
-                 lines = 3
-             elif layout_width <= width * 0.25:
-                 if layout_height / layout_width > 1.2:  # tall and narrow: do not split
-                     yield _BBox(
-                         layout_index=layout_index,
-                         fragment_index=0,
-                         virtual=True,
-                         order=0,
-                         value=(x0, y0, x1, y1),
-                     )
-                     return
-                 else:  # not tall and narrow: still split into two lines
-                     lines = 2
-
-         lines = max(1, lines)
-         line_height = (y1 - y0) / lines
-         current_y = y0
-
-         for i in range(lines):
-             yield _BBox(
-                 layout_index=layout_index,
-                 fragment_index=i,
-                 virtual=True,
-                 order=0,
-                 value=(x0, current_y, x1, current_y + line_height),
-             )
-             current_y += line_height
-
-     def _median(self, numbers: list[int]) -> float:
-         sorted_numbers = sorted(numbers)
-         n = len(sorted_numbers)
-
-         # check whether the count is odd or even
-         if n % 2 == 1:
-             # odd count: take the middle element directly
-             return float(sorted_numbers[n // 2])
-         else:
-             # even count: average the two middle elements
-             mid1 = sorted_numbers[n // 2 - 1]
-             mid2 = sorted_numbers[n // 2]
-             return float((mid1 + mid2) / 2)
-
-     def _squeeze(self, bbox: tuple[float, float, float, float], width: int, height: int) -> tuple[float, float, float, float]:
-         x0, y0, x1, y1 = bbox
-         x0 = self._squeeze_value(x0, width)
-         x1 = self._squeeze_value(x1, width)
-         y0 = self._squeeze_value(y0, height)
-         y1 = self._squeeze_value(y1, height)
-         return x0, y0, x1, y1
-
-     def _squeeze_value(self, position: float, size: int) -> float:
-         if position < 0:
-             position = 0.0
-         if position > size:
-             position = float(size)
-         return position
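A note on the removed ordering step: before inference, _order_and_get_bbox_list clamps every box to the page, rescales it into layoutreader's fixed 1000x1000 coordinate space, rounds, and sorts. A minimal standalone sketch of that normalization (the function name scale_to_layoutreader is ours, for illustration only):

    def scale_to_layoutreader(
        bboxes: list[tuple[float, float, float, float]],
        width: int,
        height: int,
        target: float = 1000.0,
    ) -> list[tuple[int, int, int, int]]:
        # clamp each coordinate to the page, then rescale to [0, 1000] and
        # round, mirroring _squeeze plus the x_scale / y_scale loop above
        scaled = []
        for x0, y0, x1, y1 in bboxes:
            x0 = min(max(x0, 0.0), float(width))
            x1 = min(max(x1, 0.0), float(width))
            y0 = min(max(y0, 0.0), float(height))
            y1 = min(max(y1, 0.0), float(height))
            scaled.append((
                round(x0 * target / width),
                round(y0 * target / height),
                round(x1 * target / width),
                round(y1 * target / height),
            ))
        # layoutreader expects pre-sorted boxes; unsorted input degrades the predicted order
        return sorted(scaled)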
doc_page_extractor/layoutreader.py DELETED
@@ -1,126 +0,0 @@
- # Copy from https://github.com/ppaanngggg/layoutreader/blob/main/v3/helpers.py
- from collections import defaultdict
- from typing import List, Dict
-
- import torch
- from transformers import LayoutLMv3ForTokenClassification
-
- MAX_LEN = 510
- CLS_TOKEN_ID = 0
- UNK_TOKEN_ID = 3
- EOS_TOKEN_ID = 2
-
-
- class DataCollator:
-     def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
-         bbox = []
-         labels = []
-         input_ids = []
-         attention_mask = []
-
-         # clip bbox and labels to max length, build input_ids and attention_mask
-         for feature in features:
-             _bbox = feature["source_boxes"]
-             if len(_bbox) > MAX_LEN:
-                 _bbox = _bbox[:MAX_LEN]
-             _labels = feature["target_index"]
-             if len(_labels) > MAX_LEN:
-                 _labels = _labels[:MAX_LEN]
-             _input_ids = [UNK_TOKEN_ID] * len(_bbox)
-             _attention_mask = [1] * len(_bbox)
-             assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
-             bbox.append(_bbox)
-             labels.append(_labels)
-             input_ids.append(_input_ids)
-             attention_mask.append(_attention_mask)
-
-         # add CLS and EOS tokens
-         for i in range(len(bbox)):
-             bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
-             labels[i] = [-100] + labels[i] + [-100]
-             input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
-             attention_mask[i] = [1] + attention_mask[i] + [1]
-
-         # padding to max length
-         max_len = max(len(x) for x in bbox)
-         for i in range(len(bbox)):
-             bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
-             labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
-             input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
-             attention_mask[i] = attention_mask[i] + [0] * (
-                 max_len - len(attention_mask[i])
-             )
-
-         ret = {
-             "bbox": torch.tensor(bbox),
-             "attention_mask": torch.tensor(attention_mask),
-             "labels": torch.tensor(labels),
-             "input_ids": torch.tensor(input_ids),
-         }
-         # set label > MAX_LEN to -100, because original labels may be > MAX_LEN
-         ret["labels"][ret["labels"] > MAX_LEN] = -100
-         # set label > 0 to label-1, because original labels are 1-indexed
-         ret["labels"][ret["labels"] > 0] -= 1
-         return ret
-
-
- def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
-     bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
-     input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
-     attention_mask = [1] + [1] * len(boxes) + [1]
-     return {
-         "bbox": torch.tensor([bbox]),
-         "attention_mask": torch.tensor([attention_mask]),
-         "input_ids": torch.tensor([input_ids]),
-     }
-
-
- def prepare_inputs(
-     inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
- ) -> Dict[str, torch.Tensor]:
-     ret = {}
-     for k, v in inputs.items():
-         v = v.to(model.device)
-         if torch.is_floating_point(v):
-             v = v.to(model.dtype)
-         ret[k] = v
-     return ret
-
-
- def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
-     """
-     parse logits to orders
-
-     :param logits: logits from model
-     :param length: input length
-     :return: orders
-     """
-     logits = logits[1 : length + 1, :length]
-     orders = logits.argsort(descending=False).tolist()
-     ret = [o.pop() for o in orders]
-     while True:
-         order_to_idxes = defaultdict(list)
-         for idx, order in enumerate(ret):
-             order_to_idxes[order].append(idx)
-         # filter idxes len > 1
-         order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
-         if not order_to_idxes:
-             break
-         # filter
-         for order, idxes in order_to_idxes.items():
-             # find original logits of idxes
-             idxes_to_logit = {}
-             for idx in idxes:
-                 idxes_to_logit[idx] = logits[idx, order]
-             idxes_to_logit = sorted(
-                 idxes_to_logit.items(), key=lambda x: x[1], reverse=True
-             )
-             # keep the highest logit as order, set others to next candidate
-             for idx, _ in idxes_to_logit[1:]:
-                 ret[idx] = orders[idx].pop()
-
-     return ret
-
-
- def check_duplicate(a: List[int]) -> bool:
-     return len(a) != len(set(a))
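Taken together, boxes2inputs, prepare_inputs, and parse_logits form a small reading-order pipeline around the hantian/layoutreader checkpoint, which is how the deleted LayoutOrder drove them. A hedged usage sketch (assumes the checkpoint can be loaded from the Hub; the example boxes are illustrative):

    import torch
    from transformers import LayoutLMv3ForTokenClassification

    model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
    model.eval()

    # boxes must already be integers in layoutreader's 1000x1000 space
    boxes = [[50, 50, 500, 80], [50, 100, 500, 130], [550, 50, 950, 80]]
    with torch.no_grad():
        inputs = prepare_inputs(boxes2inputs(boxes), model)
        logits = model(**inputs).logits.cpu().squeeze(0)

    orders = parse_logits(logits, len(boxes))
    # re-sort into the predicted reading order, mirroring the deleted LayoutOrder
    ordered_boxes = [boxes[i] for i in orders]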
doc_page_extractor/ocr.py DELETED
@@ -1,175 +0,0 @@
1
- import numpy as np
2
- import cv2
3
- import os
4
-
5
- from typing import Literal, Generator
6
- from dataclasses import dataclass
7
- from .onnxocr import TextSystem
8
- from .types import OCRFragment
9
- from .rectangle import Rectangle
10
- from .downloader import download
11
- from .utils import is_space_text
12
-
13
-
14
- _MODELS = (
15
- ("ppocrv4", "rec", "rec.onnx"),
16
- ("ppocrv4", "cls", "cls.onnx"),
17
- ("ppocrv4", "det", "det.onnx"),
18
- ("ch_ppocr_server_v2.0", "ppocr_keys_v1.txt"),
19
- )
20
-
21
- @dataclass
22
- class _OONXParams:
23
- use_angle_cls: bool
24
- use_gpu: bool
25
- rec_image_shape: tuple[int, int, int]
26
- cls_image_shape: tuple[int, int, int]
27
- cls_batch_num: int
28
- cls_thresh: float
29
- label_list: list[str]
30
-
31
- det_algorithm: str
32
- det_limit_side_len: int
33
- det_limit_type: str
34
- det_db_thresh: float
35
- det_db_box_thresh: float
36
- det_db_unclip_ratio: float
37
- use_dilation: bool
38
- det_db_score_mode: str
39
- det_box_type: str
40
- rec_batch_num: int
41
- drop_score: float
42
- save_crop_res: bool
43
- rec_algorithm: str
44
- use_space_char: bool
45
- rec_model_dir: str
46
- cls_model_dir: str
47
- det_model_dir: str
48
- rec_char_dict_path: str
49
-
50
- class OCR:
51
- def __init__(
52
- self,
53
- device: Literal["cpu", "cuda"],
54
- model_dir_path: str,
55
- ):
56
- self._device: Literal["cpu", "cuda"] = device
57
- self._model_dir_path: str = model_dir_path
58
- self._text_system: TextSystem | None = None
59
-
60
- def search_fragments(self, image: np.ndarray) -> Generator[OCRFragment, None, None]:
61
- for box, res in self._ocr(image):
62
- text, rank = res
63
- if is_space_text(text):
64
- continue
65
-
66
- rect = Rectangle(
67
- lt=(box[0][0], box[0][1]),
68
- rt=(box[1][0], box[1][1]),
69
- rb=(box[2][0], box[2][1]),
70
- lb=(box[3][0], box[3][1]),
71
- )
72
- if not rect.is_valid or rect.area == 0.0:
73
- continue
74
-
75
- yield OCRFragment(
76
- order=0,
77
- text=text,
78
- rank=rank,
79
- rect=rect,
80
- )
81
-
82
- def _ocr(self, image: np.ndarray) -> Generator[tuple[list[list[float]], tuple[str, float]], None, None]:
83
- text_system = self._get_text_system()
84
- image = self._preprocess_image(image)
85
- dt_boxes, rec_res = text_system(image)
86
-
87
- for box, res in zip(dt_boxes, rec_res):
88
- yield box.tolist(), res
89
-
90
- def _get_text_system(self) -> TextSystem:
91
- if self._text_system is None:
92
- for model_path in _MODELS:
93
- file_path = os.path.join(self._model_dir_path, *model_path)
94
- if os.path.exists(file_path):
95
- continue
96
-
97
- file_dir_path = os.path.dirname(file_path)
98
- os.makedirs(file_dir_path, exist_ok=True)
99
-
100
- url_path = "/".join(model_path)
101
- url = f"https://huggingface.co/moskize/OnnxOCR/resolve/main/{url_path}"
102
- download(url, file_path)
103
-
104
- self._text_system = TextSystem(_OONXParams(
105
- use_angle_cls=True,
106
- use_gpu=(self._device != "cpu"),
107
- rec_image_shape=(3, 48, 320),
108
- cls_image_shape=(3, 48, 192),
109
- cls_batch_num=6,
110
- cls_thresh=0.9,
111
- label_list=["0", "180"],
112
- det_algorithm="DB",
113
- det_limit_side_len=960,
114
- det_limit_type="max",
115
- det_db_thresh=0.3,
116
- det_db_box_thresh=0.6,
117
- det_db_unclip_ratio=1.5,
118
- use_dilation=False,
119
- det_db_score_mode="fast",
120
- det_box_type="quad",
121
- rec_batch_num=6,
122
- drop_score=0.5,
123
- save_crop_res=False,
124
- rec_algorithm="SVTR_LCNet",
125
- use_space_char=True,
126
- rec_model_dir=os.path.join(self._model_dir_path, *_MODELS[0]),
127
- cls_model_dir=os.path.join(self._model_dir_path, *_MODELS[1]),
128
- det_model_dir=os.path.join(self._model_dir_path, *_MODELS[2]),
129
- rec_char_dict_path=os.path.join(self._model_dir_path, *_MODELS[3]),
130
- ))
131
-
132
- return self._text_system
133
-
134
- def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
135
- image = self._alpha_to_color(image, (255, 255, 255))
136
- # image = cv2.bitwise_not(image) # inv
137
- # image = self._binarize_img(image) # bin
138
- image = cv2.normalize(
139
- src=image,
140
- dst=np.zeros((image.shape[0], image.shape[1])),
141
- alpha=0,
142
- beta=255,
143
- norm_type=cv2.NORM_MINMAX,
144
- )
145
- image = cv2.fastNlMeansDenoisingColored(
146
- src=image,
147
- dst=None,
148
- h=10,
149
- hColor=10,
150
- templateWindowSize=7,
151
- searchWindowSize=15,
152
- )
153
- # image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # image to gray
154
- return image
155
-
156
- def _alpha_to_color(self, image: np.ndarray, alpha_color: tuple[float, float, float]) -> np.ndarray:
157
- if len(image.shape) == 3 and image.shape[2] == 4:
158
- B, G, R, A = cv2.split(image)
159
- alpha = A / 255
160
-
161
- R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
162
- G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
163
- B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
164
-
165
- image = cv2.merge((B, G, R))
166
-
167
- return image
168
-
169
- def _binarize_img(self, image: np.ndarray):
170
- if len(image.shape) == 3 and image.shape[2] == 3:
171
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # conversion to grayscale image
172
- # use cv2 threshold binarization
173
- _, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
174
- image = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
175
- return image
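For reference, this deleted OCR wrapper was self-provisioning: the first call fetched its ONNX models from the moskize/OnnxOCR Hub repo into model_dir_path, then yielded only non-empty, geometrically valid fragments. A minimal usage sketch (the file paths are illustrative, not part of the package):

    import numpy as np
    from PIL import Image

    ocr = OCR(device="cpu", model_dir_path="./models")  # downloads ONNX models on first use
    image = np.array(Image.open("page.png").convert("RGB"))

    for fragment in ocr.search_fragments(image):
        # each fragment carries the recognized text, a confidence rank, and a quad Rectangle
        print(fragment.text, fragment.rank, fragment.rect.wrapper)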
doc_page_extractor/ocr_corrector.py DELETED
@@ -1,126 +0,0 @@
- import numpy as np
-
- from typing import Iterable
- from shapely.geometry import Polygon
- from PIL.Image import new, Image, Resampling
- from .types import Layout, OCRFragment
- from .ocr import OCR
- from .overlap import overlap_rate
- from .rectangle import Point, Rectangle
-
-
- _MIN_RATE = 0.5
-
- def correct_fragments(ocr: OCR, source: Image, layout: Layout):
-     x1, y1, x2, y2 = layout.rect.wrapper
-     image: Image = source.crop((
-         round(x1), round(y1),
-         round(x2), round(y2),
-     ))
-     image, dx, dy, scale = _adjust_image(image)
-     image_np = np.array(image)
-     ocr_fragments = list(ocr.search_fragments(image_np))
-     corrected_fragments: list[OCRFragment] = []
-
-     for fragment in ocr_fragments:
-         _apply_fragment(fragment.rect, layout, dx, dy, scale)
-
-     matched_fragments, not_matched_fragments = _match_fragments(
-         zone_rect=layout.rect,
-         fragments1=layout.fragments,
-         fragments2=ocr_fragments,
-     )
-     for fragment1, fragment2 in matched_fragments:
-         if fragment1.rank > fragment2.rank:
-             corrected_fragments.append(fragment1)
-         else:
-             corrected_fragments.append(fragment2)
-
-     corrected_fragments.extend(not_matched_fragments)
-     layout.fragments = corrected_fragments
-
- def _adjust_image(image: Image) -> tuple[Image, int, int, float]:
-     # after testing, adding white borders to images can reduce
-     # the possibility of some text not being recognized
-     border_size: int = 50
-     adjusted_size: int = 1024 - 2 * border_size
-     width, height = image.size
-     core_width = float(max(adjusted_size, width))
-     core_height = float(max(adjusted_size, height))
-
-     scale_x = core_width / width
-     scale_y = core_height / height
-     scale = min(scale_x, scale_y)
-     adjusted_width = width * scale
-     adjusted_height = height * scale
-
-     dx = (core_width - adjusted_width) / 2.0
-     dy = (core_height - adjusted_height) / 2.0
-     dx = round(dx) + border_size
-     dy = round(dy) + border_size
-
-     if scale != 1.0:
-         width = round(width * scale)
-         height = round(height * scale)
-         image = image.resize((width, height), Resampling.BICUBIC)
-
-     width = round(core_width) + 2 * border_size
-     height = round(core_height) + 2 * border_size
-     new_image = new("RGB", (width, height), (255, 255, 255))
-     new_image.paste(image, (dx, dy))
-
-     return new_image, dx, dy, scale
-
- def _apply_fragment(rect: Rectangle, layout: Layout, dx: int, dy: int, scale: float):
-     rect.lt = _apply_point(rect.lt, layout, dx, dy, scale)
-     rect.lb = _apply_point(rect.lb, layout, dx, dy, scale)
-     rect.rb = _apply_point(rect.rb, layout, dx, dy, scale)
-     rect.rt = _apply_point(rect.rt, layout, dx, dy, scale)
-
- def _apply_point(point: Point, layout: Layout, dx: int, dy: int, scale: float) -> Point:
-     x, y = point
-     x = (x - dx) / scale + layout.rect.lt[0]
-     y = (y - dy) / scale + layout.rect.lt[1]
-     return x, y
-
- def _match_fragments(
-     zone_rect: Rectangle,
-     fragments1: Iterable[OCRFragment],
-     fragments2: Iterable[OCRFragment],
- ) -> tuple[list[tuple[OCRFragment, OCRFragment]], list[OCRFragment]]:
-
-     zone_polygon = Polygon(zone_rect)
-     fragments2: list[OCRFragment] = list(fragments2)
-     matched_fragments: list[tuple[OCRFragment, OCRFragment]] = []
-     not_matched_fragments: list[OCRFragment] = []
-
-     for fragment1 in fragments1:
-         polygon1 = Polygon(fragment1.rect)
-         polygon1 = zone_polygon.intersection(polygon1)
-         if polygon1.is_empty:
-             continue
-
-         best_j = -1
-         best_rate = 0.0
-
-         for j, fragment2 in enumerate(fragments2):
-             polygon2 = Polygon(fragment2.rect)
-             rate = overlap_rate(polygon1, polygon2)
-             if rate < _MIN_RATE:
-                 continue
-
-             if rate > best_rate:
-                 best_j = j
-                 best_rate = rate
-
-         if best_j != -1:
-             matched_fragments.append((
-                 fragment1,
-                 fragments2[best_j],
-             ))
-             del fragments2[best_j]
-         else:
-             not_matched_fragments.append(fragment1)
-
-     not_matched_fragments.extend(fragments2)
-     return matched_fragments, not_matched_fragments
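The crop/pad/scale bookkeeping above is the part most worth double-checking: _apply_point must exactly invert the scale-then-offset that _adjust_image applies when pasting the resized crop into the white canvas. A standalone sketch of that round-trip, under our reading of the paste semantics (both helper names are ours, not the package's):

    def to_padded(x: float, y: float, dx: int, dy: int, scale: float) -> tuple[float, float]:
        # forward mapping: source-crop coordinates into the padded canvas
        return x * scale + dx, y * scale + dy

    def to_source(x: float, y: float, dx: int, dy: int, scale: float) -> tuple[float, float]:
        # inverse mapping, as in _apply_point (minus the layout offset)
        return (x - dx) / scale, (y - dy) / scale

    x, y = to_padded(12.0, 34.0, dx=50, dy=50, scale=2.0)
    assert to_source(x, y, dx=50, dy=50, scale=2.0) == (12.0, 34.0)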
doc_page_extractor/onnxocr/__init__.py DELETED
@@ -1 +0,0 @@
- from .predict_system import TextSystem
@@ -1,26 +0,0 @@
1
- class ClsPostProcess (object):
2
- """ Convert between text-label and text-index """
3
-
4
- def __init__(self, label_list=None, key=None, **kwargs):
5
- super(ClsPostProcess, self).__init__()
6
- self.label_list = label_list
7
- self.key = key
8
-
9
- def __call__(self, preds, label=None, *args, **kwargs):
10
- if self.key is not None:
11
- preds = preds[self.key]
12
-
13
- label_list = self.label_list
14
- if label_list is None:
15
- label_list = {idx: idx for idx in range(preds.shape[-1])}
16
-
17
- # if isinstance(preds, paddle.Tensor):
18
- # preds = preds.numpy()
19
-
20
- pred_idxs = preds.argmax(axis=1)
21
- decode_out = [(label_list[idx], preds[i, idx])
22
- for i, idx in enumerate(pred_idxs)]
23
- if label is None:
24
- return decode_out
25
- label = [(label_list[idx], 1.0) for idx in label]
26
- return decode_out, label
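ClsPostProcess maps the angle classifier's per-row argmax back to a label with its score. With the label_list=["0", "180"] configuration the deleted ocr.py passed in, a quick sketch:

    import numpy as np

    post = ClsPostProcess(label_list=["0", "180"])
    preds = np.array([
        [0.9, 0.1],  # upright
        [0.2, 0.8],  # rotated 180 degrees
    ])
    print(post(preds))  # -> [("0", 0.9), ("180", 0.8)]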