custom-layoutparser 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- custom_layoutparser-0.1.0.dist-info/METADATA +5 -0
- custom_layoutparser-0.1.0.dist-info/RECORD +36 -0
- custom_layoutparser-0.1.0.dist-info/WHEEL +5 -0
- custom_layoutparser-0.1.0.dist-info/top_level.txt +1 -0
- layoutparser/__init__.py +89 -0
- layoutparser/elements/__init__.py +25 -0
- layoutparser/elements/base.py +275 -0
- layoutparser/elements/errors.py +26 -0
- layoutparser/elements/layout.py +348 -0
- layoutparser/elements/layout_elements.py +1352 -0
- layoutparser/elements/utils.py +82 -0
- layoutparser/file_utils.py +235 -0
- layoutparser/io/__init__.py +2 -0
- layoutparser/io/basic.py +148 -0
- layoutparser/io/pdf.py +225 -0
- layoutparser/models/__init__.py +18 -0
- layoutparser/models/auto_layoutmodel.py +70 -0
- layoutparser/models/base_catalog.py +34 -0
- layoutparser/models/base_layoutmodel.py +88 -0
- layoutparser/models/detectron2/__init__.py +18 -0
- layoutparser/models/detectron2/catalog.py +142 -0
- layoutparser/models/detectron2/layoutmodel.py +168 -0
- layoutparser/models/effdet/__init__.py +16 -0
- layoutparser/models/effdet/catalog.py +88 -0
- layoutparser/models/effdet/layoutmodel.py +256 -0
- layoutparser/models/model_config.py +133 -0
- layoutparser/models/paddledetection/__init__.py +17 -0
- layoutparser/models/paddledetection/catalog.py +214 -0
- layoutparser/models/paddledetection/layoutmodel.py +297 -0
- layoutparser/ocr/__init__.py +16 -0
- layoutparser/ocr/base.py +41 -0
- layoutparser/ocr/gcv_agent.py +288 -0
- layoutparser/ocr/tesseract_agent.py +193 -0
- layoutparser/tools/__init__.py +5 -0
- layoutparser/tools/shape_operations.py +167 -0
- layoutparser/visualization.py +571 -0
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
# Copyright 2021 The Layout Parser team and Paddle Detection model
|
|
2
|
+
# contributors. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import List, Union, Dict, Any, Tuple
|
|
17
|
+
import os
|
|
18
|
+
from functools import reduce
|
|
19
|
+
import warnings
|
|
20
|
+
|
|
21
|
+
from PIL import Image
|
|
22
|
+
import cv2
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
from .catalog import PathManager, LABEL_MAP_CATALOG, MODEL_CATALOG
|
|
26
|
+
from ..base_layoutmodel import BaseLayoutModel
|
|
27
|
+
from ...elements import Rectangle, TextBlock, Layout
|
|
28
|
+
|
|
29
|
+
from ...file_utils import is_paddle_available
|
|
30
|
+
|
|
31
|
+
if is_paddle_available():
|
|
32
|
+
import paddle.inference
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Public API of this module: only the layout model class is exported.
__all__ = ["PaddleDetectionLayoutModel"]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _resize_image(
|
|
39
|
+
image: np.ndarray, target_size: Tuple[int, int]
|
|
40
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
41
|
+
"""
|
|
42
|
+
Args:
|
|
43
|
+
image (np.ndarray): image (np.ndarray)
|
|
44
|
+
Returns:
|
|
45
|
+
image (np.ndarray): processed image (np.ndarray)
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
origin_shape = image.shape[:2]
|
|
49
|
+
|
|
50
|
+
resize_h, resize_w = target_size
|
|
51
|
+
# im_scale_y: the resize ratio of Y
|
|
52
|
+
im_scale_y = resize_h / float(origin_shape[0])
|
|
53
|
+
# the resize ratio of X
|
|
54
|
+
im_scale_x = resize_w / float(origin_shape[1])
|
|
55
|
+
|
|
56
|
+
# resize image
|
|
57
|
+
image = cv2.resize(image, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
|
|
58
|
+
scale_factor = np.array([im_scale_y, im_scale_x]).astype("float32")
|
|
59
|
+
return image, scale_factor
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class PaddleDetectionLayoutModel(BaseLayoutModel):
    """Create a PaddleDetection-based Layout Detection Model

    Args:
        config_path (:obj:`str`):
            The path to the configuration file.
        model_path (:obj:`str`, None):
            The path to the saved weights of the model.
            If set, overwrite the weights in the configuration file.
            Defaults to `None`.
        label_map (:obj:`dict`, optional):
            The map from the model prediction (ids) to real
            word labels (strings). If the resolved model path is a catalog
            (``lp://``) path, Layout Parser will automatically initialize the
            label_map from the dataset name in that path.
            Defaults to `None`.
        device (:obj:`str`, optional):
            Set to ``"cuda"`` to run the model on GPU 0. Any other value,
            including `None`, runs the model on the CPU.
            Defaults to `None`.
        enforce_cpu (:obj:`bool`, optional):
            Deprecated. Setting it only emits a :class:`DeprecationWarning`
            and has no other effect; use `device` instead.
            Defaults to `None`.
        extra_config (:obj:`dict`, optional):
            Extra configuration passed to the PaddleDetection model configuration.
            Defaults to `{}`.
            Including arguments:
                enable_mkldnn (:obj:`bool`, optional):
                    Whether use mkldnn to accelerate the computation (CPU only).
                    Defaults to False.
                thread_num (:obj:`int`, optional):
                    The number of CPU math-library threads (CPU only).
                    Defaults to 10.
                threshold (:obj:`float`, optional):
                    Detections whose score is not strictly greater than this
                    value are dropped from the output.
                    Defaults to 0.5.
                target_size (:obj:`list`, optional):
                    The ``[height, width]`` the image is resized to.
                    Defaults to [640,640].
                pixel_mean (:obj:`np.ndarray`, optional):
                    Per-channel mean subtracted after scaling pixels to [0, 1].
                    Defaults to ``[[[0.485, 0.456, 0.406]]]``.
                pixel_std (:obj:`np.ndarray`, optional):
                    Per-channel std used to divide after mean subtraction.
                    Defaults to ``[[[0.229, 0.224, 0.225]]]``.

    Examples::

        >>> import layoutparser as lp
        >>> model = lp.models.PaddleDetectionLayoutModel(
        ...     'lp://PubLayNet/ppyolov2_r50vd_dcn_365e/config')
        >>> model.detect(image)

    """

    DEPENDENCIES = ["paddle"]
    DETECTOR_NAME = "paddledetection"
    MODEL_CATALOG = MODEL_CATALOG

    def __init__(
        self,
        config_path=None,
        model_path=None,
        label_map=None,
        device=None,
        enforce_cpu=None,
        extra_config=None,
    ):
        if enforce_cpu is not None:
            warnings.warn(
                "Setting enforce_cpu is deprecated. Please set `device` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        if extra_config is None:
            extra_config = {}

        _, model_path = self.config_parser(config_path, model_path)
        model_dir = PathManager.get_local_path(model_path)

        if label_map is None:
            if model_path.startswith("lp://"):
                # Slice the scheme off. ``str.lstrip("lp://")`` strips a *set*
                # of characters and can eat the start of the next component.
                dataset_name = model_path[len("lp://"):].split("/")[1]
                label_map = LABEL_MAP_CATALOG[dataset_name]
            else:
                label_map = {}

        self.label_map = label_map

        # TODO: rethink how to save store the default constants
        self.predictor = self.load_predictor(
            model_dir,
            device=device,
            enable_mkldnn=extra_config.get("enable_mkldnn", False),
            thread_num=extra_config.get("thread_num", 10),
        )

        self.threshold = extra_config.get("threshold", 0.5)
        self.target_size = extra_config.get("target_size", [640, 640])
        self.pixel_mean = extra_config.get(
            "pixel_mean", np.array([[[0.485, 0.456, 0.406]]])
        )
        self.pixel_std = extra_config.get(
            "pixel_std", np.array([[[0.229, 0.224, 0.225]]])
        )

    def load_predictor(
        self,
        model_dir,
        device=None,
        enable_mkldnn=False,
        thread_num=10,
    ):
        """Build a Paddle Inference predictor for an exported model directory.

        Args:
            model_dir (str): Directory containing ``inference.pdmodel`` and
                ``inference.pdiparams``.
            device (str, optional): ``"cuda"`` to run on GPU 0; any other
                value runs on the CPU. Defaults to `None` (CPU).
            enable_mkldnn (bool, optional): On CPU, try to enable MKLDNN.
                If enabling fails, a warning is emitted instead.
                Defaults to `False`.
            thread_num (int, optional): Number of CPU math-library threads.
                Defaults to 10.

        Returns:
            The predictor returned by ``paddle.inference.create_predictor``.
        """
        config = paddle.inference.Config(
            # TODO: Move them to some constants
            os.path.join(model_dir, "inference.pdmodel"),
            os.path.join(model_dir, "inference.pdiparams"),
        )

        if device == "cuda":
            # initial GPU memory(M), device ID
            # 2000 is an appropriate value for PaddleDetection model
            config.enable_use_gpu(2000, 0)
            # optimize graph and fuse op
            config.switch_ir_optim(True)
        else:
            config.disable_gpu()
            config.set_cpu_math_library_num_threads(thread_num)
            if enable_mkldnn:
                # Enable MKLDNN only inside the guarded block so that an
                # unsupported environment falls back with a warning.
                try:
                    # cache 10 different shapes for mkldnn to avoid memory leak
                    config.set_mkldnn_cache_capacity(10)
                    config.enable_mkldnn()
                except Exception:
                    warnings.warn(
                        "The current environment does not support `mkldnn`, so disable mkldnn."
                    )

        # disable print log when predict
        config.disable_glog_info()
        # enable shared memory
        config.enable_memory_optim()
        # disable feed, fetch OP, needed by zero_copy_run
        config.switch_use_feed_fetch_ops(False)
        return paddle.inference.create_predictor(config)

    def preprocess(self, image):
        """Resize, normalize and batch an image for the predictor.

        Args:
            image (np.ndarray): An ``H x W x C`` image array.

        Returns:
            dict: ``image`` (1 x C x H' x W'), ``im_shape`` (1 x 2) and
            ``scale_factor`` (1 x 2) arrays, all float32.
        """
        # resize image by target_size
        image, scale_factor = _resize_image(image, self.target_size)
        input_shape = np.array(image.shape[:2]).astype("float32")
        # normalize image
        image = (image / 255.0 - self.pixel_mean) / self.pixel_std
        # HWC -> CHW
        image = image.transpose((2, 0, 1)).copy()

        inputs = {}
        inputs["image"] = np.array(image)[np.newaxis, :].astype("float32")
        inputs["im_shape"] = np.array(input_shape)[np.newaxis, :].astype("float32")
        inputs["scale_factor"] = np.array(scale_factor)[np.newaxis, :].astype("float32")
        return inputs

    def gather_output(self, np_boxes):
        """Convert raw detector output into a :obj:`~layoutparser.Layout`.

        Args:
            np_boxes (np.ndarray): Rows of
                ``[class_id, score, x_1, y_1, x_2, y_2]``.

        Returns:
            :obj:`~layoutparser.Layout`: One :obj:`TextBlock` per kept box.
            The layout is empty when nothing was detected.
        """
        layout = Layout()

        # Fewer than 6 values cannot hold a single box row. Return early:
        # indexing an empty 1-D array with ``[:, 1]`` would raise IndexError.
        if np_boxes.size < 6:
            warnings.warn("[WARNING] No object detected.")
            return layout

        # Keep boxes above the score threshold whose class id is > -1.
        keep = (np_boxes[:, 1] > self.threshold) & (np_boxes[:, 0] > -1)

        for np_box in np_boxes[keep, :]:
            clsid, bbox, score = int(np_box[0]), np_box[2:], np_box[1]
            x_1, y_1, x_2, y_2 = bbox

            layout.append(
                TextBlock(
                    Rectangle(x_1, y_1, x_2, y_2),
                    type=self.label_map.get(clsid, clsid),
                    score=score,
                )
            )

        return layout

    def detect(self, image):
        """Detect the layout of a given image.

        Args:
            image (:obj:`np.ndarray` or `PIL.Image`): The input image to detect.

        Returns:
            :obj:`~layoutparser.Layout`: The detected layout of the input image
        """
        # Convert PIL / grayscale input to a 3-channel array
        image = self.image_loader(image)
        inputs = self.preprocess(image)

        for input_name in self.predictor.get_input_names():
            input_tensor = self.predictor.get_input_handle(input_name)
            input_tensor.copy_from_cpu(inputs[input_name])

        self.predictor.run()
        output_names = self.predictor.get_output_names()
        # The first output holds the detected boxes.
        boxes_tensor = self.predictor.get_output_handle(output_names[0])
        np_boxes = boxes_tensor.copy_to_cpu()

        return self.gather_output(np_boxes)

    def image_loader(self, image: Union["np.ndarray", "Image.Image"]):
        """Return the input as an ``np.ndarray`` suitable for :meth:`preprocess`.

        PIL images are converted to RGB and then to an array. A 2-D
        (single-channel) array is replicated to three channels, matching PIL's
        ``L`` -> ``RGB`` conversion, so it broadcasts against the per-channel
        ``pixel_mean`` / ``pixel_std``. Other arrays are returned unchanged.
        """
        if isinstance(image, Image.Image):
            if image.mode != "RGB":
                image = image.convert("RGB")
            image = np.array(image)
        elif isinstance(image, np.ndarray) and image.ndim == 2:
            image = np.stack([image] * 3, axis=-1)

        return image
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2021 The Layout Parser team. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .gcv_agent import GCVAgent, GCVFeatureType
|
|
16
|
+
from .tesseract_agent import TesseractAgent, TesseractFeatureType
|
layoutparser/ocr/base.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright 2021 The Layout Parser team. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from abc import ABC, abstractmethod
|
|
16
|
+
from enum import IntEnum
|
|
17
|
+
|
|
18
|
+
from ..file_utils import requires_backends
|
|
19
|
+
|
|
20
|
+
class BaseOCRElementType(IntEnum):
    """Integer enum base for the hierarchy levels of an OCR engine's output."""

    @property
    @abstractmethod
    def attr_name(self):
        """Name of the attribute holding elements of this level.

        Subclasses are expected to override this. The base implementation
        returns None.
        """
        return None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class BaseOCRAgent(ABC):
    """Abstract base class for OCR agents.

    Before an instance is created, the class's ``DEPENDENCIES`` are passed to
    ``requires_backends``, presumably to verify that the optional packages are
    installed (TODO confirm against ``file_utils``).
    """

    @property
    @abstractmethod
    def DEPENDENCIES(self):
        """DEPENDENCIES lists all necessary dependencies for the class."""

    def __new__(cls, *args, **kwargs):
        # Check the backends first so that a missing dependency fails at
        # construction time; constructor args are left for __init__.
        requires_backends(cls, cls.DEPENDENCIES)
        instance = super().__new__(cls)
        return instance

    @abstractmethod
    def detect(self, image):
        """Run OCR on ``image``; implemented by concrete agents."""
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
# Copyright 2021 The Layout Parser team. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import io
|
|
16
|
+
import os
|
|
17
|
+
import json
|
|
18
|
+
import warnings
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
from cv2 import imencode
|
|
22
|
+
|
|
23
|
+
from .base import BaseOCRAgent, BaseOCRElementType
|
|
24
|
+
from ..elements import Layout, TextBlock, Quadrilateral, TextBlock
|
|
25
|
+
from ..file_utils import is_gcv_available
|
|
26
|
+
|
|
27
|
+
if is_gcv_available():
|
|
28
|
+
import google.protobuf.json_format as _json_format
|
|
29
|
+
import google.cloud.vision as _vision
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _cvt_GCV_vertices_to_points(vertices):
|
|
34
|
+
return np.array([[vertex.x, vertex.y] for vertex in vertices])
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class GCVFeatureType(BaseOCRElementType):
    """
    The element types from Google Cloud Vision API
    """

    PAGE = 0
    BLOCK = 1
    PARA = 2
    WORD = 3
    SYMBOL = 4

    @property
    def attr_name(self):
        """Attribute name on the parent GCV object that lists elements of this level."""
        # Member values are 0..4 in hierarchy order, so they index the tuple.
        return ("pages", "blocks", "paragraphs", "words", "symbols")[self.value]

    @property
    def child_level(self):
        """The next finer level in the hierarchy, or None for SYMBOL."""
        if self is GCVFeatureType.SYMBOL:
            return None
        return GCVFeatureType(self.value + 1)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class GCVAgent(BaseOCRAgent):
    """A wrapper for `Google Cloud Vision (GCV) <https://cloud.google.com/vision>`_ Text
    Detection APIs.

    Note:
        Google Cloud Vision API returns the output text in two types:

        * `text_annotations`:

            In this format, GCV automatically finds the best aggregation
            level for the text, and returns the results in a list. We use
            :obj:`~gather_text_annotations` to retrieve this type of
            information.

        * `full_text_annotation`:

            To support better user control, GCV also provides the
            `full_text_annotation` output, where it returns the hierarchical
            structure of the output text. To process this output, we provide
            the :obj:`~gather_full_text_annotation` function to aggregate the
            texts of the given aggregation level.
    """

    DEPENDENCIES = ["google-cloud-vision"]

    def __init__(self, languages=None, ocr_image_decode_type=".png"):
        """Create a Google Cloud Vision OCR Agent.

        Args:
            languages (:obj:`list`, optional):
                You can specify the language code of the documents to detect to improve
                accuracy. The supported language and their code can be found on `this page
                <https://cloud.google.com/vision/docs/languages>`_.
                Defaults to None.

            ocr_image_decode_type (:obj:`str`, optional):
                The format to convert the input image to before sending for GCV OCR.
                Defaults to `".png"`.

                    * `".png"` is suggested as it does not compress the image.
                    * But `".jpg"` could also be a good choice if the input image is very large.
        """
        try:
            self._client = _vision.ImageAnnotatorClient()
        except Exception:
            # Client creation failed (the warning assumes missing credentials).
            # Keep the agent usable for offline helpers such as load_response
            # and gather_*; _detect reports the missing client explicitly.
            self._client = None
            warnings.warn(
                "The GCV credential has not been set. You could not run the detect command."
            )
        self._context = _vision.types.ImageContext(language_hints=languages)
        self.ocr_image_decode_type = ocr_image_decode_type

    @classmethod
    def with_credential(cls, credential_path, **kwargs):
        """Specify the credential to use for the GCV OCR API.

        This sets the process-wide ``GOOGLE_APPLICATION_CREDENTIALS``
        environment variable, then creates the agent with ``kwargs``.

        Args:
            credential_path (:obj:`str`): The path to the credential file
        """
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credential_path
        return cls(**kwargs)

    def _detect(self, img_content):
        """Send encoded image bytes to GCV document text detection.

        Raises:
            RuntimeError: If the GCV client could not be created in ``__init__``.
        """
        client = getattr(self, "_client", None)
        if client is None:
            raise RuntimeError(
                "The GCV client is not available; set the credential (for example "
                "with `GCVAgent.with_credential`) before calling detect."
            )
        gcv_image = _vision.types.Image(content=img_content)
        return client.document_text_detection(
            image=gcv_image, image_context=self._context
        )

    def detect(
        self,
        image,
        return_response=False,
        return_only_text=False,
        agg_output_level=None,
    ):
        """Send the input image for OCR.

        Args:
            image (:obj:`np.ndarray` or :obj:`str`):
                The input image array or the name of the image file
            return_response (:obj:`bool`, optional):
                Whether directly return the google cloud response.
                Defaults to `False`.
            return_only_text (:obj:`bool`, optional):
                Whether return only the texts in the OCR results.
                Defaults to `False`.
            agg_output_level (:obj:`~GCVFeatureType`, optional):
                When set, aggregate the GCV output with respect to the
                specified aggregation level. Defaults to `None`.

        Raises:
            TypeError: If ``image`` is neither an array nor a file path.
            ValueError: If the array cannot be encoded in ``ocr_image_decode_type``.
        """
        if isinstance(image, np.ndarray):
            success, encoded = imencode(self.ocr_image_decode_type, image)
            if not success:
                raise ValueError(
                    f"Failed to encode the image as {self.ocr_image_decode_type!r}."
                )
            # ``tobytes`` replaces the deprecated ``tostring`` (same bytes).
            img_content = encoded.tobytes()
        elif isinstance(image, (str, os.PathLike)):
            with io.open(image, "rb") as image_file:
                img_content = image_file.read()
        else:
            raise TypeError(
                f"Unsupported image type {type(image).__name__}; "
                "expected np.ndarray or a file path."
            )

        res = self._detect(img_content)

        if return_response:
            return res

        if return_only_text:
            return res.full_text_annotation.text

        if agg_output_level is not None:
            return self.gather_full_text_annotation(res, agg_output_level)

        return self.gather_text_annotations(res)

    @staticmethod
    def gather_text_annotations(response):
        """Convert the text_annotations from GCV output to an :obj:`Layout` object.

        Args:
            response (:obj:`AnnotateImageResponse`):
                The returned Google Cloud Vision AnnotateImageResponse object.

        Returns:
            :obj:`Layout`: The retrieved layout from the response.
        """
        # The 0th element contains all texts, so it is skipped here.
        doc = response.text_annotations[1:]
        gathered_text = Layout()

        for i, text_comp in enumerate(doc):
            points = _cvt_GCV_vertices_to_points(text_comp.bounding_poly.vertices)
            gathered_text.append(
                TextBlock(block=Quadrilateral(points), text=text_comp.description, id=i)
            )

        return gathered_text

    @staticmethod
    def gather_full_text_annotation(response, agg_level):
        """Convert the full_text_annotation from GCV output to an :obj:`Layout` object.

        Args:
            response (:obj:`AnnotateImageResponse`):
                The returned Google Cloud Vision AnnotateImageResponse object.

            agg_level (:obj:`~GCVFeatureType`):
                The layout level to aggregate the text in full_text_annotation.

        Returns:
            :obj:`Layout`: The retrieved layout from the response. The layout
            is empty if the response contains no text.
        """
        if agg_level == GCVFeatureType.PAGE:
            # text_annotations[0] holds the whole page text. The list is empty
            # when nothing was detected, so indexing it would raise IndexError.
            if not response.text_annotations:
                return Layout()
            doc = response.text_annotations[0]
            points = _cvt_GCV_vertices_to_points(doc.bounding_poly.vertices)
            return Layout([TextBlock(block=Quadrilateral(points), text=doc.description)])

        text_blocks = []
        element_id = 0

        def iter_level(
            parent,
            agg_level=None,
            text_blocks=None,
            texts=None,
            cur_level=GCVFeatureType.PAGE,
        ):
            """Recursively walk ``parent`` and emit one block per ``agg_level`` item."""
            nonlocal element_id

            for item in getattr(parent, cur_level.attr_name):
                if cur_level == agg_level:
                    texts = []

                # Go down levels to fetch the texts
                if cur_level == GCVFeatureType.SYMBOL:
                    texts.append(item.text)
                elif (
                    cur_level == GCVFeatureType.WORD
                    and agg_level != GCVFeatureType.SYMBOL
                ):
                    # Symbols of a word are concatenated without spaces.
                    chars = []
                    iter_level(
                        item, agg_level, text_blocks, chars, cur_level.child_level
                    )
                    texts.append("".join(chars))
                else:
                    iter_level(
                        item, agg_level, text_blocks, texts, cur_level.child_level
                    )

                if cur_level == agg_level:
                    points = _cvt_GCV_vertices_to_points(item.bounding_box.vertices)
                    text_blocks.append(
                        TextBlock(
                            block=Quadrilateral(points),
                            text=" ".join(texts),
                            score=item.confidence,
                            id=element_id,
                        )
                    )
                    element_id += 1

        iter_level(response.full_text_annotation, agg_level, text_blocks)
        return Layout(text_blocks)

    def load_response(self, filename):
        """Load a GCV response saved as JSON (e.g. by :meth:`save_response`)."""
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(filename, "r", encoding="utf-8") as f:
            data = f.read()
        return _json_format.Parse(
            data, _vision.types.AnnotateImageResponse(), ignore_unknown_fields=True
        )

    def save_response(self, res, file_name):
        """Serialize a GCV response to a JSON file."""
        res = _json_format.MessageToJson(res)

        # Re-serialize through ``json.dump`` (default options, no indent).
        with open(file_name, "w", encoding="utf-8") as f:
            json_file = json.loads(res)
            json.dump(json_file, f)
|