onnxtr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/__init__.py +2 -0
- onnxtr/contrib/__init__.py +0 -0
- onnxtr/contrib/artefacts.py +131 -0
- onnxtr/contrib/base.py +105 -0
- onnxtr/file_utils.py +33 -0
- onnxtr/io/__init__.py +5 -0
- onnxtr/io/elements.py +455 -0
- onnxtr/io/html.py +28 -0
- onnxtr/io/image.py +56 -0
- onnxtr/io/pdf.py +42 -0
- onnxtr/io/reader.py +85 -0
- onnxtr/models/__init__.py +4 -0
- onnxtr/models/_utils.py +141 -0
- onnxtr/models/builder.py +355 -0
- onnxtr/models/classification/__init__.py +2 -0
- onnxtr/models/classification/models/__init__.py +1 -0
- onnxtr/models/classification/models/mobilenet.py +120 -0
- onnxtr/models/classification/predictor/__init__.py +1 -0
- onnxtr/models/classification/predictor/base.py +57 -0
- onnxtr/models/classification/zoo.py +76 -0
- onnxtr/models/detection/__init__.py +2 -0
- onnxtr/models/detection/core.py +101 -0
- onnxtr/models/detection/models/__init__.py +3 -0
- onnxtr/models/detection/models/differentiable_binarization.py +159 -0
- onnxtr/models/detection/models/fast.py +160 -0
- onnxtr/models/detection/models/linknet.py +160 -0
- onnxtr/models/detection/postprocessor/__init__.py +0 -0
- onnxtr/models/detection/postprocessor/base.py +144 -0
- onnxtr/models/detection/predictor/__init__.py +1 -0
- onnxtr/models/detection/predictor/base.py +54 -0
- onnxtr/models/detection/zoo.py +73 -0
- onnxtr/models/engine.py +50 -0
- onnxtr/models/predictor/__init__.py +1 -0
- onnxtr/models/predictor/base.py +175 -0
- onnxtr/models/predictor/predictor.py +145 -0
- onnxtr/models/preprocessor/__init__.py +1 -0
- onnxtr/models/preprocessor/base.py +118 -0
- onnxtr/models/recognition/__init__.py +2 -0
- onnxtr/models/recognition/core.py +28 -0
- onnxtr/models/recognition/models/__init__.py +5 -0
- onnxtr/models/recognition/models/crnn.py +226 -0
- onnxtr/models/recognition/models/master.py +145 -0
- onnxtr/models/recognition/models/parseq.py +134 -0
- onnxtr/models/recognition/models/sar.py +134 -0
- onnxtr/models/recognition/models/vitstr.py +166 -0
- onnxtr/models/recognition/predictor/__init__.py +1 -0
- onnxtr/models/recognition/predictor/_utils.py +86 -0
- onnxtr/models/recognition/predictor/base.py +79 -0
- onnxtr/models/recognition/utils.py +89 -0
- onnxtr/models/recognition/zoo.py +69 -0
- onnxtr/models/zoo.py +114 -0
- onnxtr/transforms/__init__.py +1 -0
- onnxtr/transforms/base.py +112 -0
- onnxtr/utils/__init__.py +4 -0
- onnxtr/utils/common_types.py +18 -0
- onnxtr/utils/data.py +126 -0
- onnxtr/utils/fonts.py +41 -0
- onnxtr/utils/geometry.py +498 -0
- onnxtr/utils/multithreading.py +50 -0
- onnxtr/utils/reconstitution.py +70 -0
- onnxtr/utils/repr.py +64 -0
- onnxtr/utils/visualization.py +291 -0
- onnxtr/utils/vocabs.py +71 -0
- onnxtr/version.py +1 -0
- onnxtr-0.1.0.dist-info/LICENSE +201 -0
- onnxtr-0.1.0.dist-info/METADATA +481 -0
- onnxtr-0.1.0.dist-info/RECORD +70 -0
- onnxtr-0.1.0.dist-info/WHEEL +5 -0
- onnxtr-0.1.0.dist-info/top_level.txt +2 -0
- onnxtr-0.1.0.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
|
+
|
|
8
|
+
from typing import List, Union
|
|
9
|
+
|
|
10
|
+
import cv2
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pyclipper
|
|
13
|
+
from shapely.geometry import Polygon
|
|
14
|
+
|
|
15
|
+
from ..core import DetectionPostProcessor
|
|
16
|
+
|
|
17
|
+
__all__ = ["GeneralDetectionPostProcessor"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class GeneralDetectionPostProcessor(DetectionPostProcessor):
    """Implements a post processor for FAST model.

    Args:
    ----
        bin_thresh: threshold used to binarize p_map at inference time
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: whether the inputs were expected to have horizontal text elements
    """

    def __init__(
        self,
        bin_thresh: float = 0.1,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
    ) -> None:
        super().__init__(box_thresh, bin_thresh, assume_straight_pages)
        # Expansion factor used when unclipping shrunk text polygons back to their full extent
        self.unclip_ratio = 1.5

    def polygon_to_box(
        self,
        points: np.ndarray,
    ) -> np.ndarray:
        """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon

        Args:
        ----
            points: polygon vertices, in absolute pixel coordinates

        Returns:
        -------
            a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle),
            or None when the offsetting produced no polygon
        """
        if not self.assume_straight_pages:
            # Compute the rectangle polygon enclosing the raw polygon
            rect = cv2.minAreaRect(points)
            points = cv2.boxPoints(rect)
            # Add 1 pixel to correct cv2 approx
            area = (rect[1][0] + 1) * (1 + rect[1][1])
            length = 2 * (rect[1][0] + rect[1][1]) + 2
        else:
            poly = Polygon(points)
            area = poly.area
            length = poly.length
        distance = area * self.unclip_ratio / length  # compute distance to expand polygon
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        _points = offset.Execute(distance)
        # Take biggest stack of points (offsetting may split the shape into several parts)
        idx = 0
        if len(_points) > 1:
            max_size = 0
            for _idx, p in enumerate(_points):
                if len(p) > max_size:
                    idx = _idx
                    max_size = len(p)
            # We ensure that _points can be correctly casted to a ndarray
            _points = [_points[idx]]
        expanded_points: np.ndarray = np.asarray(_points)  # expand polygon
        if len(expanded_points) < 1:
            return None  # type: ignore[return-value]
        return (
            cv2.boundingRect(expanded_points)  # type: ignore[return-value]
            if self.assume_straight_pages
            else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
        )

    def bitmap_to_boxes(
        self,
        pred: np.ndarray,
        bitmap: np.ndarray,
    ) -> np.ndarray:
        """Compute boxes from a bitmap/pred_map: find connected components then filter boxes

        Args:
        ----
            pred: Pred map from differentiable linknet output
            bitmap: Bitmap map computed from pred (binarized)

        Returns:
        -------
            if assume_straight_pages: an (N, 5) array of (xmin, ymin, xmax, ymax, score) rows
            in relative coordinates; otherwise an (N, 4, 2) array of relative quadrangles
        """
        height, width = bitmap.shape[:2]
        boxes: List[Union[np.ndarray, List[float]]] = []
        # get contours from connected components on the bitmap
        contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            # Check whether smallest enclosing bounding box is not too small
            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
                continue
            # Compute objectness
            if self.assume_straight_pages:
                x, y, w, h = cv2.boundingRect(contour)
                points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]])
                score = self.box_score(pred, points, assume_straight_pages=True)
            else:
                score = self.box_score(pred, contour, assume_straight_pages=False)

            if score < self.box_thresh:  # remove polygons with a weak objectness
                continue

            if self.assume_straight_pages:
                _box = self.polygon_to_box(points)
            else:
                _box = self.polygon_to_box(np.squeeze(contour))

            if self.assume_straight_pages:
                # compute relative polygon to get rid of img shape
                x, y, w, h = _box
                xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
                boxes.append([xmin, ymin, xmax, ymax, score])
            else:
                # compute relative box to get rid of img shape
                _box[:, 0] /= width
                _box[:, 1] /= height
                boxes.append(_box)

        if not self.assume_straight_pages:
            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
        else:
            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import *
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any, List, Tuple, Union
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from onnxtr.models.preprocessor import PreProcessor
|
|
11
|
+
from onnxtr.utils.repr import NestedObject
|
|
12
|
+
|
|
13
|
+
__all__ = ["DetectionPredictor"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DetectionPredictor(NestedObject):
    """Localizes text elements in a document.

    Args:
    ----
        pre_processor: transform inputs for easier batched model inference
        model: core detection architecture
    """

    _children_names: List[str] = ["pre_processor", "model"]

    def __init__(
        self,
        pre_processor: PreProcessor,
        model: Any,
    ) -> None:
        self.pre_processor = pre_processor
        self.model = model

    def __call__(
        self,
        pages: List[np.ndarray],
        return_maps: bool = False,
        **kwargs: Any,
    ) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]:
        # Every page must be a multi-channel 2D image (H, W, C)
        for page in pages:
            if page.ndim != 3:
                raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")

        batches = self.pre_processor(pages)

        # Run the model batch by batch, flattening per-sample outputs across batches
        preds: List[np.ndarray] = []
        seg_maps: List[np.ndarray] = []
        for batch in batches:
            out = self.model(batch, return_preds=True, return_model_output=True, **kwargs)
            preds.extend(out["preds"])
            seg_maps.extend(out["out_map"])

        if return_maps:
            return preds, seg_maps
        return preds
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from .. import detection
|
|
9
|
+
from ..preprocessor import PreProcessor
|
|
10
|
+
from .predictor import DetectionPredictor
|
|
11
|
+
|
|
12
|
+
__all__ = ["detection_predictor"]
|
|
13
|
+
|
|
14
|
+
ARCHS = [
    "db_resnet34",
    "db_resnet50",
    "db_mobilenet_v3_large",
    "linknet_resnet18",
    "linknet_resnet34",
    "linknet_resnet50",
    "fast_tiny",
    "fast_small",
    "fast_base",
]


def _predictor(arch: Any, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
    """Build a detection predictor from an architecture name or a ready model instance."""
    if isinstance(arch, str):
        if arch not in ARCHS:
            raise ValueError(f"unknown architecture '{arch}'")

        # Instantiate the requested architecture from the detection module
        _model = detection.__dict__[arch](assume_straight_pages=assume_straight_pages)
    else:
        if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
            raise ValueError(f"unknown architecture: {type(arch)}")

        _model = arch
        _model.postprocessor.assume_straight_pages = assume_straight_pages

    # Fall back to the model's own normalization stats and a default batch size
    kwargs.setdefault("mean", _model.cfg["mean"])
    kwargs.setdefault("std", _model.cfg["std"])
    kwargs.setdefault("batch_size", 4)
    return DetectionPredictor(
        PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
        _model,
    )
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def detection_predictor(
    arch: Any = "fast_base",
    assume_straight_pages: bool = True,
    **kwargs: Any,
) -> DetectionPredictor:
    """Text detection architecture.

    >>> import numpy as np
    >>> from onnxtr.models import detection_predictor
    >>> model = detection_predictor(arch='db_resnet50')
    >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
    >>> out = model([input_page])

    Args:
    ----
        arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
        assume_straight_pages: If True, fit straight boxes to the page
        **kwargs: optional keyword arguments passed to the architecture

    Returns:
    -------
        Detection predictor
    """
    # Delegate construction (and arch validation) to the private factory
    predictor = _predictor(arch, assume_straight_pages, **kwargs)
    return predictor
|
onnxtr/models/engine.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any, List, Union
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import onnxruntime
|
|
10
|
+
|
|
11
|
+
from onnxtr.utils.data import download_from_url
|
|
12
|
+
from onnxtr.utils.geometry import shape_translate
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Engine:
    """Implements an abstract class for the engine of a model

    Args:
    ----
        url: the url to use to download a model if needed, or a local file path
        providers: list of onnxruntime execution providers to use for inference;
            defaults to ["CPUExecutionProvider", "CUDAExecutionProvider"]
        **kwargs: additional arguments to be passed to `download_from_url`
    """

    def __init__(
        self, url: str, providers: Union[List[str], None] = None, **kwargs: Any
    ) -> None:
        # Avoid a shared mutable default argument: build the provider list per call
        if providers is None:
            providers = ["CPUExecutionProvider", "CUDAExecutionProvider"]
        # Only download for actual URLs; a bare substring test would also match
        # local paths that merely contain "http"
        archive_path = download_from_url(url, cache_subdir="models", **kwargs) if url.startswith("http") else url
        self.runtime = onnxruntime.InferenceSession(archive_path, providers=providers)
        self.runtime_inputs = self.runtime.get_inputs()[0]
        # TensorFlow-exported models are channels-last: the last input dim is the 3 color channels
        self.tf_exported = int(self.runtime_inputs.shape[-1]) == 3
        # A static batch dim is an int; a dynamic one shows up as a string placeholder
        self.fixed_batch_size: Union[int, str] = self.runtime_inputs.shape[
            0
        ]  # mostly possible with tensorflow exported models
        self.output_name = [output.name for output in self.runtime.get_outputs()]

    def run(self, inputs: np.ndarray) -> np.ndarray:
        """Run inference on a batch and return the outputs in channels-last (BHWC) layout.

        Args:
        ----
            inputs: preprocessed batch of images

        Returns:
        -------
            model output, translated to BHWC
        """
        # Feed the model in the layout it was exported with
        if self.tf_exported:
            inputs = shape_translate(inputs, format="BHWC")  # sanity check
        else:
            inputs = shape_translate(inputs, format="BCHW")
        if isinstance(self.fixed_batch_size, int) and self.fixed_batch_size != 0:  # dynamic batch size is a string
            inputs = np.broadcast_to(inputs, (self.fixed_batch_size, *inputs.shape))
            # combine the results
            logits = np.concatenate(
                [self.runtime.run(self.output_name, {"input": batch})[0] for batch in inputs], axis=0
            )
        else:
            logits = self.runtime.run(self.output_name, {"input": inputs})[0]
        return shape_translate(logits, format="BHWC")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .predictor import *
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from onnxtr.models.builder import DocumentBuilder
|
|
11
|
+
from onnxtr.utils.geometry import extract_crops, extract_rcrops
|
|
12
|
+
|
|
13
|
+
from .._utils import rectify_crops, rectify_loc_preds
|
|
14
|
+
from ..classification import crop_orientation_predictor
|
|
15
|
+
from ..classification.predictor import OrientationPredictor
|
|
16
|
+
from ..detection.zoo import ARCHS as DETECTION_ARCHS
|
|
17
|
+
from ..recognition.zoo import ARCHS as RECOGNITION_ARCHS
|
|
18
|
+
|
|
19
|
+
__all__ = ["_OCRPredictor"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class _OCRPredictor:
    """Implements an object able to localize and identify text elements in a set of documents

    Args:
    ----
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
            without rotated textual elements.
        straighten_pages: if True, estimates the page general orientation based on the median line orientation.
            Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
            accordingly. Doing so will improve performances for documents with page-uniform rotations.
        preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
        symmetric_pad: if True and preserve_aspect_ratio is True, pads the image symmetrically.
        **kwargs: keyword args of `DocumentBuilder`
    """

    # Only instantiated when crops may be rotated (assume_straight_pages=False)
    crop_orientation_predictor: Optional[OrientationPredictor]

    def __init__(
        self,
        assume_straight_pages: bool = True,
        straighten_pages: bool = False,
        preserve_aspect_ratio: bool = True,
        symmetric_pad: bool = True,
        **kwargs: Any,
    ) -> None:
        self.assume_straight_pages = assume_straight_pages
        self.straighten_pages = straighten_pages
        # Crop-level orientation estimation is unnecessary for straight pages
        self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor()
        self.doc_builder = DocumentBuilder(**kwargs)
        self.preserve_aspect_ratio = preserve_aspect_ratio
        self.symmetric_pad = symmetric_pad
        # User-registered callables applied to loc_preds before cropping (see add_hook)
        self.hooks: List[Callable] = []

    @staticmethod
    def _generate_crops(
        pages: List[np.ndarray],
        loc_preds: List[np.ndarray],
        channels_last: bool,
        assume_straight_pages: bool = False,
    ) -> List[List[np.ndarray]]:
        """Extract one crop per detected box on each page.

        Straight boxes use axis-aligned crops; otherwise rotated crops are extracted.
        """
        extraction_fn = extract_crops if assume_straight_pages else extract_rcrops

        crops = [
            extraction_fn(page, _boxes[:, :4], channels_last=channels_last)  # type: ignore[operator]
            for page, _boxes in zip(pages, loc_preds)
        ]
        return crops

    @staticmethod
    def _prepare_crops(
        pages: List[np.ndarray],
        loc_preds: List[np.ndarray],
        channels_last: bool,
        assume_straight_pages: bool = False,
    ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
        """Generate crops and drop degenerate (zero-sized) ones together with their boxes."""
        crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)

        # Avoid sending zero-sized crops
        is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
        crops = [
            [crop for crop, _kept in zip(page_crops, page_kept) if _kept]
            for page_crops, page_kept in zip(crops, is_kept)
        ]
        loc_preds = [_boxes[_kept] for _boxes, _kept in zip(loc_preds, is_kept)]

        return crops, loc_preds

    def _rectify_crops(
        self,
        crops: List[List[np.ndarray]],
        loc_preds: List[np.ndarray],
    ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]:
        """Rotate crops (and remap their boxes) according to the predicted crop orientation.

        Returns the rectified crops, rectified boxes, and a flat list of
        (orientation class, confidence) tuples, one per crop.
        """
        # Work at a page level
        orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops])  # type: ignore[misc]
        rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
        rect_loc_preds = [
            rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds
            for page_loc_preds, orientation in zip(loc_preds, orientations)
        ]
        # Flatten to list of tuples with (value, confidence)
        crop_orientations = [
            (orientation, prob)
            for page_classes, page_probs in zip(classes, probs)
            for orientation, prob in zip(page_classes, page_probs)
        ]
        return rect_crops, rect_loc_preds, crop_orientations  # type: ignore[return-value]

    def _remove_padding(
        self,
        pages: List[np.ndarray],
        loc_preds: List[np.ndarray],
    ) -> List[np.ndarray]:
        """Remap box coordinates from the padded, aspect-preserving frame back to the page frame.

        NOTE(review): coordinates are modified in place on the input arrays.
        """
        if self.preserve_aspect_ratio:
            # Rectify loc_preds to remove padding
            rectified_preds = []
            for page, loc_pred in zip(pages, loc_preds):
                h, w = page.shape[0], page.shape[1]
                if h > w:
                    # y unchanged, dilate x coord
                    if self.symmetric_pad:
                        # padding was split on both sides: rescale around the 0.5 center
                        if self.assume_straight_pages:
                            loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
                        else:
                            loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
                    else:
                        if self.assume_straight_pages:
                            loc_pred[:, [0, 2]] *= h / w
                        else:
                            loc_pred[:, :, 0] *= h / w
                elif w > h:
                    # x unchanged, dilate y coord
                    if self.symmetric_pad:
                        if self.assume_straight_pages:
                            loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
                        else:
                            loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
                    else:
                        if self.assume_straight_pages:
                            loc_pred[:, [1, 3]] *= w / h
                        else:
                            loc_pred[:, :, 1] *= w / h
                rectified_preds.append(loc_pred)
            return rectified_preds
        return loc_preds

    @staticmethod
    def _process_predictions(
        loc_preds: List[np.ndarray],
        word_preds: List[Tuple[str, float]],
        crop_orientations: List[Dict[str, Any]],
    ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]:
        """Regroup flat per-crop predictions into per-page lists aligned with loc_preds."""
        text_preds = []
        crop_orientation_preds = []
        if len(loc_preds) > 0:
            # Text & crop orientation predictions at page level
            _idx = 0
            for page_boxes in loc_preds:
                text_preds.append(word_preds[_idx : _idx + page_boxes.shape[0]])
                crop_orientation_preds.append(crop_orientations[_idx : _idx + page_boxes.shape[0]])
                _idx += page_boxes.shape[0]

        return loc_preds, text_preds, crop_orientation_preds

    def add_hook(self, hook: Callable) -> None:
        """Add a hook to the predictor

        Args:
        ----
            hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
        """
        self.hooks.append(hook)

    def list_archs(self) -> Dict[str, List[str]]:
        """Return the available detection and recognition architecture names."""
        return {"detection_archs": DETECTION_ARCHS, "recognition_archs": RECOGNITION_ARCHS}
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any, List
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from onnxtr.io.elements import Document
|
|
11
|
+
from onnxtr.models._utils import estimate_orientation, get_language
|
|
12
|
+
from onnxtr.models.detection.predictor import DetectionPredictor
|
|
13
|
+
from onnxtr.models.recognition.predictor import RecognitionPredictor
|
|
14
|
+
from onnxtr.utils.geometry import rotate_image
|
|
15
|
+
from onnxtr.utils.repr import NestedObject
|
|
16
|
+
|
|
17
|
+
from .base import _OCRPredictor
|
|
18
|
+
|
|
19
|
+
__all__ = ["OCRPredictor"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class OCRPredictor(NestedObject, _OCRPredictor):
    """Implements an object able to localize and identify text elements in a set of documents

    Args:
    ----
        det_predictor: detection module
        reco_predictor: recognition module
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
            without rotated textual elements.
        straighten_pages: if True, estimates the page general orientation based on the median line orientation.
            Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
            accordingly. Doing so will improve performances for documents with page-uniform rotations.
        preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
        symmetric_pad: if True and preserve_aspect_ratio is True, pads the image symmetrically
        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        detect_language: if True, the language prediction will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        **kwargs: keyword args of `DocumentBuilder`
    """

    _children_names = ["det_predictor", "reco_predictor", "doc_builder"]

    def __init__(
        self,
        det_predictor: DetectionPredictor,
        reco_predictor: RecognitionPredictor,
        assume_straight_pages: bool = True,
        straighten_pages: bool = False,
        preserve_aspect_ratio: bool = True,
        symmetric_pad: bool = True,
        detect_orientation: bool = False,
        detect_language: bool = False,
        **kwargs: Any,
    ) -> None:
        self.det_predictor = det_predictor
        self.reco_predictor = reco_predictor
        _OCRPredictor.__init__(
            self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
        )
        self.detect_orientation = detect_orientation
        self.detect_language = detect_language

    def __call__(
        self,
        pages: List[np.ndarray],
        **kwargs: Any,
    ) -> Document:
        """Run the full OCR pipeline (detection, optional deskewing, recognition) on the pages.

        Args:
        ----
            pages: list of multi-channel 2D page images (H, W, C)
            **kwargs: forwarded to the detection and recognition predictors

        Returns:
        -------
            the structured `Document` produced by the document builder
        """
        # Dimension check
        if any(page.ndim != 3 for page in pages):
            raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")

        origin_page_shapes = [page.shape[:2] for page in pages]

        # Localize text elements
        loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

        # Detect document rotation and rotate pages
        # Binarize the raw segmentation maps with the postprocessor's own threshold
        seg_maps = [
            np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8)
            for out_map in out_maps
        ]
        if self.detect_orientation:
            origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
            orientations = [
                {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
            ]
        else:
            orientations = None
        if self.straighten_pages:
            # Reuse the orientations computed above when available
            origin_page_orientations = (
                origin_page_orientations
                if self.detect_orientation
                else [estimate_orientation(seq_map) for seq_map in seg_maps]
            )
            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
            # forward again to get predictions on straight pages
            loc_preds = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]

        # Keep only the box array of each page prediction
        loc_preds = [loc_pred[0] for loc_pred in loc_preds]

        # Rectify crops if aspect ratio
        loc_preds = self._remove_padding(pages, loc_preds)

        # Apply hooks to loc_preds if any
        for hook in self.hooks:
            loc_preds = hook(loc_preds)

        # Crop images
        crops, loc_preds = self._prepare_crops(
            pages,
            loc_preds,  # type: ignore[arg-type]
            channels_last=True,
            assume_straight_pages=self.assume_straight_pages,
        )
        # Rectify crop orientation and get crop orientation predictions
        crop_orientations: Any = []
        if not self.assume_straight_pages:
            crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
            crop_orientations = [
                {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
            ]

        # Identify character sequences
        word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
        if not crop_orientations:
            # No orientation estimation was run: default every crop to upright with no confidence
            crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]

        boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)

        if self.detect_language:
            languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]
            languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
        else:
            languages_dict = None

        out = self.doc_builder(
            pages,
            boxes,
            text_preds,
            origin_page_shapes,  # type: ignore[arg-type]
            crop_orientations,
            orientations,
            languages_dict,
        )
        return out
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import *
|