doc-page-extractor 0.0.5__tar.gz → 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of doc-page-extractor might be problematic. Click here for more details.

Files changed (39) hide show
  1. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/PKG-INFO +17 -5
  2. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/README.md +14 -2
  3. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/__init__.py +1 -1
  4. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/downloader.py +4 -1
  5. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/extractor.py +7 -13
  6. doc_page_extractor-0.0.7/doc_page_extractor/ocr.py +172 -0
  7. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/ocr_corrector.py +3 -3
  8. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/__init__.py +1 -0
  9. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/cls_postprocess.py +26 -0
  10. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/db_postprocess.py +246 -0
  11. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/imaug.py +32 -0
  12. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/operators.py +187 -0
  13. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/predict_base.py +52 -0
  14. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/predict_cls.py +89 -0
  15. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/predict_det.py +120 -0
  16. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/predict_rec.py +321 -0
  17. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/predict_system.py +97 -0
  18. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/rec_postprocess.py +896 -0
  19. doc_page_extractor-0.0.7/doc_page_extractor/onnxocr/utils.py +71 -0
  20. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor.egg-info/PKG-INFO +17 -5
  21. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor.egg-info/SOURCES.txt +12 -0
  22. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor.egg-info/requires.txt +2 -2
  23. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/setup.py +3 -3
  24. doc_page_extractor-0.0.5/doc_page_extractor/ocr.py +0 -120
  25. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/LICENSE +0 -0
  26. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/clipper.py +0 -0
  27. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/layoutreader.py +0 -0
  28. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/overlap.py +0 -0
  29. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/plot.py +0 -0
  30. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/raw_optimizer.py +0 -0
  31. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/rectangle.py +0 -0
  32. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/rotation.py +0 -0
  33. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/types.py +0 -0
  34. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor/utils.py +0 -0
  35. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor.egg-info/dependency_links.txt +0 -0
  36. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/doc_page_extractor.egg-info/top_level.txt +0 -0
  37. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/setup.cfg +0 -0
  38. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/tests/__init__.py +0 -0
  39. {doc_page_extractor-0.0.5 → doc_page_extractor-0.0.7}/tests/test_history_bus.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: doc-page-extractor
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: doc page extractor can identify text and format in images and return structured data.
5
5
  Home-page: https://github.com/Moskize91/doc-page-extractor
6
6
  Author: Tao Zeyu
@@ -9,11 +9,11 @@ Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
10
  Requires-Dist: opencv-python<5.0,>=4.11.0
11
11
  Requires-Dist: pillow<11.0,>=10.3
12
- Requires-Dist: numpy<1.26,>=1.24.0
12
+ Requires-Dist: pyclipper<2.0,>=1.2.0
13
+ Requires-Dist: numpy<2.0,>=1.24.0
13
14
  Requires-Dist: shapely<3.0,>=2.0.0
14
15
  Requires-Dist: transformers<5.0,>=4.48.0
15
16
  Requires-Dist: doclayout_yolo>=0.0.3
16
- Requires-Dist: paddleocr==2.9.0
17
17
  Dynamic: author
18
18
  Dynamic: author-email
19
19
  Dynamic: description
@@ -36,10 +36,20 @@ doc page extractor can identify text and format in images and return structured
36
36
  pip install doc-page-extractor
37
37
  ```
38
38
 
39
+ ```shell
40
+ pip install onnxruntime==1.21.0
41
+ ```
42
+
39
43
  ## Using CUDA
40
44
 
41
45
  Please refer to the introduction of [PyTorch](https://pytorch.org/get-started/locally/) and select the appropriate command to install according to your operating system.
42
46
 
47
+ In addition, replace the `onnxruntime` install command from the previous section with the following:
48
+
49
+ ```shell
50
+ pip install onnxruntime-gpu==1.21.0
51
+ ```
52
+
43
53
  ## Example
44
54
 
45
55
  ```python
@@ -48,7 +58,7 @@ from doc_page_extractor import DocExtractor
48
58
 
49
59
  extractor = DocExtractor(
50
60
  model_dir_path=model_path, # Folder address where AI model is downloaded and installed
51
- device="cpu", # If you want to use CUDA, please change to device="cuda:0".
61
+ device="cpu", # If you want to use CUDA, please change to device="cuda".
52
62
  )
53
63
  with Image.open("/path/to/your/image.png") as image:
54
64
  result = extractor.extract(
@@ -62,6 +72,8 @@ for layout in result.layouts:
62
72
 
63
73
  ## Acknowledgements
64
74
 
75
+ The code of `doc_page_extractor/onnxocr` in this repo comes from [OnnxOCR](https://github.com/jingsongliujing/OnnxOCR).
76
+
65
77
  - [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
66
- - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
78
+ - [OnnxOCR](https://github.com/jingsongliujing/OnnxOCR)
67
79
  - [layoutreader](https://github.com/ppaanngggg/layoutreader)
@@ -12,10 +12,20 @@ doc page extractor can identify text and format in images and return structured
12
12
  pip install doc-page-extractor
13
13
  ```
14
14
 
15
+ ```shell
16
+ pip install onnxruntime==1.21.0
17
+ ```
18
+
15
19
  ## Using CUDA
16
20
 
17
21
  Please refer to the introduction of [PyTorch](https://pytorch.org/get-started/locally/) and select the appropriate command to install according to your operating system.
18
22
 
23
+ In addition, replace the `onnxruntime` install command from the previous section with the following:
24
+
25
+ ```shell
26
+ pip install onnxruntime-gpu==1.21.0
27
+ ```
28
+
19
29
  ## Example
20
30
 
21
31
  ```python
@@ -24,7 +34,7 @@ from doc_page_extractor import DocExtractor
24
34
 
25
35
  extractor = DocExtractor(
26
36
  model_dir_path=model_path, # Folder address where AI model is downloaded and installed
27
- device="cpu", # If you want to use CUDA, please change to device="cuda:0".
37
+ device="cpu", # If you want to use CUDA, please change to device="cuda".
28
38
  )
29
39
  with Image.open("/path/to/your/image.png") as image:
30
40
  result = extractor.extract(
@@ -38,6 +48,8 @@ for layout in result.layouts:
38
48
 
39
49
  ## Acknowledgements
40
50
 
51
+ The code of `doc_page_extractor/onnxocr` in this repo comes from [OnnxOCR](https://github.com/jingsongliujing/OnnxOCR).
52
+
41
53
  - [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
42
- - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
54
+ - [OnnxOCR](https://github.com/jingsongliujing/OnnxOCR)
43
55
  - [layoutreader](https://github.com/ppaanngggg/layoutreader)
@@ -1,4 +1,4 @@
1
- from .extractor import PaddleLang, DocExtractor
1
+ from .extractor import DocExtractor
2
2
  from .clipper import clip, clip_from_image
3
3
  from .plot import plot
4
4
  from .types import ExtractedResult, OCRFragment, LayoutClass, Layout
@@ -5,9 +5,12 @@ from pathlib import Path
5
5
 
def download(url: str, file_path: Path):
    """Download *url* and store the response body at *file_path*.

    The body is streamed in chunks into a sibling ``<file_path>.part`` file and
    only moved to *file_path* once it has been written completely. A later
    existence check on *file_path* therefore never mistakes a truncated,
    interrupted download for a finished one.

    Raises:
        FileNotFoundError: if the server answers with a status code other
            than 200.
        Any exception raised while receiving or writing the body is re-raised
        after the partial ``.part`` file has been removed.
    """
    tmp_path = f"{os.fspath(file_path)}.part"
    # ``with`` releases the underlying connection on every exit path.
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            raise FileNotFoundError(f"Failed to download file from {url}: {response.status_code}")
        try:
            with open(tmp_path, "wb") as file:
                # Chunked writes keep large model files out of memory.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    file.write(chunk)
            # Atomic on the same filesystem: *file_path* appears only when complete.
            os.replace(tmp_path, file_path)
        except BaseException:
            # BaseException so that Ctrl-C also cleans up the partial file.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
@@ -1,5 +1,4 @@
1
1
  import os
2
- import torch
3
2
 
4
3
  from typing import Literal, Iterable
5
4
  from pathlib import Path
@@ -8,7 +7,7 @@ from transformers import LayoutLMv3ForTokenClassification
8
7
  from doclayout_yolo import YOLOv10
9
8
 
10
9
  from .layoutreader import prepare_inputs, boxes2inputs, parse_logits
11
- from .ocr import OCR, PaddleLang
10
+ from .ocr import OCR
12
11
  from .ocr_corrector import correct_fragments
13
12
  from .raw_optimizer import RawOptimizer
14
13
  from .rectangle import intersection_area, Rectangle
@@ -30,23 +29,18 @@ class DocExtractor:
30
29
  self._device: Literal["cpu", "cuda"] = device
31
30
  self._ocr_for_each_layouts: bool = ocr_for_each_layouts
32
31
  self._order_by_layoutreader: bool = order_by_layoutreader
33
- self._ocr: OCR = OCR(device, os.path.join(model_dir_path, "paddle"))
32
+ self._ocr: OCR = OCR(device, model_dir_path)
34
33
  self._yolo: YOLOv10 | None = None
35
34
  self._layout: LayoutLMv3ForTokenClassification | None = None
36
35
 
37
- if self._device.startswith("cuda") and not torch.cuda.is_available():
38
- self._device = "cpu"
39
- print("Warn: cuda is not available, use cpu instead")
40
-
41
36
  def extract(
42
37
  self,
43
38
  image: Image,
44
- lang: PaddleLang,
45
39
  adjust_points: bool = False,
46
40
  ) -> ExtractedResult:
47
41
 
48
42
  raw_optimizer = RawOptimizer(image, adjust_points)
49
- fragments = list(self._ocr.search_fragments(raw_optimizer.image_np, lang))
43
+ fragments = list(self._ocr.search_fragments(raw_optimizer.image_np))
50
44
  raw_optimizer.receive_raw_fragments(fragments)
51
45
 
52
46
  layouts = self._get_layouts(raw_optimizer.image)
@@ -54,7 +48,7 @@ class DocExtractor:
54
48
  layouts = remove_overlap_layouts(layouts)
55
49
 
56
50
  if self._ocr_for_each_layouts:
57
- self._correct_fragments_by_ocr_layouts(raw_optimizer.image, layouts, lang)
51
+ self._correct_fragments_by_ocr_layouts(raw_optimizer.image, layouts)
58
52
 
59
53
  if self._order_by_layoutreader:
60
54
  width, height = raw_optimizer.image.size
@@ -84,7 +78,7 @@ class DocExtractor:
84
78
  source=source,
85
79
  imgsz=1024,
86
80
  conf=0.2,
87
- device=self._device # Device to use (e.g., "cuda:0" or "cpu")
81
+ device=self._device # Device to use (e.g., "cuda" or "cpu")
88
82
  )
89
83
  boxes = det_res[0].__dict__["boxes"]
90
84
  layouts: list[Layout] = []
@@ -118,9 +112,9 @@ class DocExtractor:
118
112
  break
119
113
  return layouts
120
114
 
def _correct_fragments_by_ocr_layouts(self, source: Image, layouts: list[Layout]):
    """Re-run OCR on each layout's region of *source* via ``correct_fragments``.

    The shared OCR engine of this extractor is handed to every call.
    """
    ocr = self._ocr
    for current_layout in layouts:
        correct_fragments(ocr, source, current_layout)
124
118
 
125
119
  def _split_layouts_by_group(self, layouts: list[Layout]):
126
120
  texts_layouts: list[Layout] = []
@@ -0,0 +1,172 @@
1
+ import numpy as np
2
+ import cv2
3
+ import os
4
+
5
+ from typing import Literal, Generator
6
+ from dataclasses import dataclass
7
+ from .onnxocr import TextSystem
8
+ from .types import OCRFragment
9
+ from .rectangle import Rectangle
10
+ from .downloader import download
11
+ from .utils import is_space_text
12
+
13
+
# Files required by the ONNX OCR pipeline. Each entry is a path relative to the
# model directory; joined with "/" it is also the path inside the download
# repository. Index order matters: OCR reads [0] as the recognition model,
# [1] as the direction classifier, [2] as the detector and [3] as the
# recognition character dictionary.
_MODELS = tuple(
    ("ppocrv4", stage, f"{stage}.onnx") for stage in ("rec", "cls", "det")
) + (
    ("ch_ppocr_server_v2.0", "ppocr_keys_v1.txt"),
)
20
+
@dataclass()
class _OONXParams:
    """Settings bundle passed as the single argument to ``TextSystem``.

    NOTE(review): the field names are presumably read by attribute name inside
    the vendored ``onnxocr`` modules — confirm there before renaming any of them.
    """

    # General pipeline switches and direction-classifier settings.
    use_angle_cls: bool
    use_gpu: bool
    rec_image_shape: tuple[int, int, int]  # (channels, height, width)
    cls_image_shape: tuple[int, int, int]  # (channels, height, width)
    cls_batch_num: int
    cls_thresh: float
    label_list: list[str]  # direction labels, e.g. ["0", "180"]

    # Text-detection settings.
    det_algorithm: str
    det_limit_side_len: int
    det_limit_type: str
    det_db_thresh: float
    det_db_box_thresh: float
    det_db_unclip_ratio: float
    use_dilation: bool
    det_db_score_mode: str
    det_box_type: str

    # Recognition settings.
    rec_batch_num: int
    drop_score: float
    save_crop_res: bool
    rec_algorithm: str
    use_space_char: bool

    # Locations of the model files and the character dictionary.
    rec_model_dir: str
    cls_model_dir: str
    det_model_dir: str
    rec_char_dict_path: str
49
+
class OCR:
    """Text detection + recognition backed by the vendored ONNX ``TextSystem``.

    Model files are downloaded into *model_dir_path* and the ``TextSystem`` is
    built lazily on first use, so constructing an ``OCR`` performs no I/O.
    """

    def __init__(
        self,
        device: Literal["cpu", "cuda"],
        model_dir_path: str,
    ):
        self._device: Literal["cpu", "cuda"] = device
        self._model_dir_path: str = model_dir_path
        # Created on demand by ``_get_text_system``.
        self._text_system: TextSystem | None = None

    def search_fragments(self, image: np.ndarray) -> Generator[OCRFragment, None, None]:
        """Run OCR on *image* and yield one ``OCRFragment`` per non-blank result.

        Results whose text is whitespace-only are skipped. ``order`` counts only
        the fragments actually yielded, so it stays contiguous. ``rank`` is the
        score the recognizer returns alongside the text. ``rect`` uses the four
        detected corners in the order left-top, right-top, right-bottom,
        left-bottom.
        """
        index: int = 0
        for box, res in self._ocr(image):
            text, rank = res
            if is_space_text(text):
                continue
            yield OCRFragment(
                order=index,
                text=text,
                rank=rank,
                rect=Rectangle(
                    lt=(box[0][0], box[0][1]),
                    rt=(box[1][0], box[1][1]),
                    rb=(box[2][0], box[2][1]),
                    lb=(box[3][0], box[3][1]),
                ),
            )
            index += 1

    def _ocr(self, image: np.ndarray) -> Generator[tuple[list[list[float]], tuple[str, float]], None, None]:
        """Yield ``(box_points, (text, score))`` pairs for a preprocessed *image*."""
        text_system = self._get_text_system()
        image = self._preprocess_image(image)
        dt_boxes, rec_res = text_system(image)

        for box, res in zip(dt_boxes, rec_res):
            yield box.tolist(), res

    def _get_text_system(self) -> TextSystem:
        """Return the cached ``TextSystem``; on first call, fetch any missing model files."""
        if self._text_system is None:
            for model_path in _MODELS:
                file_path = os.path.join(self._model_dir_path, *model_path)
                if os.path.exists(file_path):
                    continue

                file_dir_path = os.path.dirname(file_path)
                os.makedirs(file_dir_path, exist_ok=True)

                url_path = "/".join(model_path)
                url = f"https://huggingface.co/moskize/OnnxOCR/resolve/main/{url_path}"
                download(url, file_path)

            self._text_system = TextSystem(_OONXParams(
                use_angle_cls=True,
                use_gpu=(self._device != "cpu"),
                rec_image_shape=(3, 48, 320),
                cls_image_shape=(3, 48, 192),
                cls_batch_num=6,
                cls_thresh=0.9,
                label_list=["0", "180"],
                det_algorithm="DB",
                det_limit_side_len=960,
                det_limit_type="max",
                det_db_thresh=0.3,
                det_db_box_thresh=0.6,
                det_db_unclip_ratio=1.5,
                use_dilation=False,
                det_db_score_mode="fast",
                det_box_type="quad",
                rec_batch_num=6,
                drop_score=0.5,
                save_crop_res=False,
                rec_algorithm="SVTR_LCNet",
                use_space_char=True,
                rec_model_dir=os.path.join(self._model_dir_path, *_MODELS[0]),
                cls_model_dir=os.path.join(self._model_dir_path, *_MODELS[1]),
                det_model_dir=os.path.join(self._model_dir_path, *_MODELS[2]),
                rec_char_dict_path=os.path.join(self._model_dir_path, *_MODELS[3]),
            ))

        return self._text_system

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Prepare *image* for the text system.

        Steps:
        - Composite a 4-channel image onto white.
        - Promote single-channel input to 3 channels, because
          ``fastNlMeansDenoisingColored`` only accepts 8-bit 3-channel images.
        - Stretch intensities to the 0..255 range.
        - Apply colour non-local-means denoising.

        NOTE(review): assumes 8-bit input; 2-channel (gray + alpha) arrays are
        not handled here.
        """
        image = self._alpha_to_color(image, (255, 255, 255))
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        # Let OpenCV allocate the output. A preallocated 2-D float buffer never
        # matches a 3-channel uint8 image and would simply be discarded.
        image = cv2.normalize(
            src=image,
            dst=None,
            alpha=0,
            beta=255,
            norm_type=cv2.NORM_MINMAX,
        )
        image = cv2.fastNlMeansDenoisingColored(
            src=image,
            dst=None,
            h=10,
            hColor=10,
            templateWindowSize=7,
            searchWindowSize=15,
        )
        return image

    def _alpha_to_color(self, image: np.ndarray, alpha_color: tuple[float, float, float]) -> np.ndarray:
        """Composite a 4-channel *image* over a solid *alpha_color* background.

        Arrays that do not have exactly four channels are returned unchanged.
        Channels are split as B, G, R, A, and ``alpha_color`` is applied in
        (R, G, B) order.

        NOTE(review): arrays converted from PIL are RGBA, so the colour order
        is effectively reversed for them. This is harmless for the white
        background used by ``_preprocess_image``.
        """
        if len(image.shape) == 3 and image.shape[2] == 4:
            B, G, R, A = cv2.split(image)
            alpha = A / 255

            R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
            G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
            B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)

            image = cv2.merge((B, G, R))

        return image

    def _binarize_img(self, image: np.ndarray):
        """Otsu-binarize a 3-channel image and return it as 3-channel BGR.

        Not used by the current pipeline; kept as an optional preprocessing step.
        """
        if len(image.shape) == 3 and image.shape[2] == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # conversion to grayscale image
            # use cv2 threshold binarization
            _, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            image = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
        return image
@@ -4,14 +4,14 @@ from typing import Iterable
4
4
  from shapely.geometry import Polygon
5
5
  from PIL.Image import new, Image, Resampling
6
6
  from .types import Layout, OCRFragment
7
- from .ocr import OCR, PaddleLang
7
+ from .ocr import OCR
8
8
  from .overlap import overlap_rate
9
9
  from .rectangle import Point, Rectangle
10
10
 
11
11
 
12
12
  _MIN_RATE = 0.5
13
13
 
14
- def correct_fragments(ocr: OCR, source: Image, layout: Layout, lang: PaddleLang):
14
+ def correct_fragments(ocr: OCR, source: Image, layout: Layout):
15
15
  x1, y1, x2, y2 = layout.rect.wrapper
16
16
  image: Image = source.crop((
17
17
  round(x1), round(y1),
@@ -19,7 +19,7 @@ def correct_fragments(ocr: OCR, source: Image, layout: Layout, lang: PaddleLang)
19
19
  ))
20
20
  image, dx, dy, scale = _adjust_image(image)
21
21
  image_np = np.array(image)
22
- ocr_fragments = list(ocr.search_fragments(image_np, lang))
22
+ ocr_fragments = list(ocr.search_fragments(image_np))
23
23
  corrected_fragments: list[OCRFragment] = []
24
24
 
25
25
  for fragment in ocr_fragments:
@@ -0,0 +1 @@
1
+ from .predict_system import TextSystem
@@ -0,0 +1,26 @@
class ClsPostProcess(object):
    """Decode text-direction classifier outputs into ``(label, score)`` pairs.

    Each row of the prediction matrix is reduced to its arg-max column. The
    label comes from ``label_list``, or is the column index itself when no
    list was given.
    """

    def __init__(self, label_list=None, key=None, **kwargs):
        super().__init__()
        self.label_list = label_list
        # Optional key used to pull the prediction array out of a mapping.
        self.key = key

    def __call__(self, preds, label=None, *args, **kwargs):
        """Return the decoded rows, plus the mapped *label* list when one is given."""
        scores = preds if self.key is None else preds[self.key]

        names = self.label_list
        if names is None:
            # Identity mapping: column index -> column index.
            names = dict(enumerate(range(scores.shape[-1])))

        winners = scores.argmax(axis=1)
        decoded = []
        for row, col in enumerate(winners):
            decoded.append((names[col], scores[row, col]))

        if label is None:
            return decoded
        return decoded, [(names[item], 1.0) for item in label]
@@ -0,0 +1,246 @@
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ This code is adapted from:
16
+ https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
17
+ """
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import numpy as np
23
+ import cv2
24
+ # import paddle
25
+ from shapely.geometry import Polygon
26
+ import pyclipper
27
+
28
+
class DBPostProcess(object):
    """
    The post process for Differentiable Binarization (DB).

    Converts the probability maps of a DB text detector into text regions in
    source-image coordinates. ``box_type='quad'`` produces 4-point boxes and
    ``box_type='poly'`` produces polygons.
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 box_type='quad',
                 **kwargs):
        # thresh: pixel threshold used to binarize the probability map.
        # box_thresh: minimum mean score for a candidate region to be kept.
        # max_candidates: cap on the number of contours examined per map.
        # unclip_ratio: scales how far each region is expanded outwards.
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        self.box_type = box_type
        assert score_mode in [
            "slow", "fast"
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)

        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        Extract polygonal text regions from one binarized map.

        pred: probability map of shape (H, W), used for scoring.
        _bitmap: map of shape (H, W) whose values are binarized as {0, 1}.
        dest_width, dest_height: size the polygons are rescaled to.
        Returns (boxes, scores): boxes is a list of [x, y] point lists.
        '''
        bitmap = _bitmap
        height, width = bitmap.shape

        boxes = []
        scores = []

        # [-2] selects the contours under both the OpenCV 3 (3-tuple) and
        # OpenCV 4 (2-tuple) return conventions.
        contours = cv2.findContours((bitmap * 255).astype(np.uint8),
                                    cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

        for contour in contours[:self.max_candidates]:
            epsilon = 0.002 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue

            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            expanded = self.unclip(points, self.unclip_ratio)
            # Keep only results that expand to exactly one polygon; an empty
            # result would otherwise crash get_mini_boxes below.
            if len(expanded) != 1:
                continue
            box = expanded.reshape(-1, 2)

            _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
            if sside < self.min_size + 2:
                continue

            box = np.array(box)
            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.tolist())
            scores.append(score)
        return boxes, scores

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        Extract 4-point text boxes from one binarized map.

        pred: probability map of shape (H, W), used for scoring.
        _bitmap: map of shape (H, W) whose values are binarized as {0, 1}.
        dest_width, dest_height: size the boxes are rescaled to.
        Returns (boxes, scores): boxes is an int32 array of shape (K, 4, 2).
        '''
        bitmap = _bitmap
        height, width = bitmap.shape

        # [-2] selects the contours under both the OpenCV 3 (3-tuple) and
        # OpenCV 4 (2-tuple) return conventions.
        contours = cv2.findContours((bitmap * 255).astype(np.uint8),
                                    cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if self.box_thresh > score:
                continue

            expanded = self.unclip(points, self.unclip_ratio)
            # Skip empty or ragged (object-dtype) unclip results: neither can
            # be reshaped into a point set for minAreaRect.
            if expanded.dtype == object or expanded.size == 0:
                continue
            box, sside = self.get_mini_boxes(expanded.reshape(-1, 1, 2))
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype("int32"))
            scores.append(score)
        return np.array(boxes, dtype="int32"), scores

    def unclip(self, box, unclip_ratio):
        '''
        Expand polygon *box* outwards by ``area * unclip_ratio / perimeter``.

        The first axis of the returned array enumerates the resulting polygons
        (normally exactly one). When pyclipper yields polygons with different
        point counts, they are returned in a 1-D object array. NumPy >= 1.24
        refuses to build such a ragged array implicitly and raises ValueError.
        '''
        poly = Polygon(box)
        if poly.length == 0:
            # Degenerate input (all points coincide): nothing to expand.
            return np.empty((0, 0, 2))
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        paths = offset.Execute(distance)
        if len({len(path) for path in paths}) > 1:
            ragged = np.empty(len(paths), dtype=object)
            for i, path in enumerate(paths):
                ragged[i] = np.array(path)
            return ragged
        return np.array(paths)

    def get_mini_boxes(self, contour):
        '''
        Fit a minimum-area rectangle around *contour*.

        Returns (box, short_side). box lists the four corners ordered
        top-left, top-right, bottom-right, bottom-left (image y grows
        downwards). short_side is the length of the rectangle's shorter side.
        '''
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        # points[0:2] are the two left-most corners, points[2:4] the right-most.
        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [
            points[index_1], points[index_2], points[index_3], points[index_4]
        ]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        '''
        Mean of *bitmap* inside polygon *_box* (an (N, 2) point array).

        The polygon is rasterised inside its clipped bounding rectangle only.
        '''
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        '''
        Mean of *bitmap* inside the polygon described by the raw *contour*.
        '''
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)

        contour[:, 0] = contour[:, 0] - xmin
        contour[:, 1] = contour[:, 1] - ymin

        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def __call__(self, outs_dict, shape_list):
        '''
        Post-process a batch of detector outputs.

        outs_dict['maps']: array indexed as (batch, channel, H, W); only
            channel 0 is used.
        shape_list: per-image (src_h, src_w, ratio_h, ratio_w) rows.
        Returns a list with one {'points': boxes} dict per image.
        '''
        pred = outs_dict['maps']
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            if self.box_type == 'poly':
                boxes, scores = self.polygons_from_bitmap(pred[batch_index],
                                                          mask, src_w, src_h)
            elif self.box_type == 'quad':
                boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                       src_w, src_h)
            else:
                raise ValueError("box_type can only be one of ['quad', 'poly']")

            boxes_batch.append({'points': boxes})
        return boxes_batch