onnxtr 0.5.1__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/contrib/__init__.py +1 -0
- onnxtr/contrib/artefacts.py +6 -8
- onnxtr/contrib/base.py +7 -16
- onnxtr/file_utils.py +1 -3
- onnxtr/io/elements.py +45 -59
- onnxtr/io/html.py +0 -2
- onnxtr/io/image.py +1 -4
- onnxtr/io/pdf.py +3 -5
- onnxtr/io/reader.py +4 -10
- onnxtr/models/_utils.py +10 -17
- onnxtr/models/builder.py +17 -30
- onnxtr/models/classification/models/mobilenet.py +7 -12
- onnxtr/models/classification/predictor/base.py +6 -7
- onnxtr/models/classification/zoo.py +25 -11
- onnxtr/models/detection/_utils/base.py +3 -7
- onnxtr/models/detection/core.py +2 -8
- onnxtr/models/detection/models/differentiable_binarization.py +10 -17
- onnxtr/models/detection/models/fast.py +10 -17
- onnxtr/models/detection/models/linknet.py +10 -17
- onnxtr/models/detection/postprocessor/base.py +3 -9
- onnxtr/models/detection/predictor/base.py +4 -5
- onnxtr/models/detection/zoo.py +20 -6
- onnxtr/models/engine.py +9 -9
- onnxtr/models/factory/hub.py +3 -7
- onnxtr/models/predictor/base.py +29 -30
- onnxtr/models/predictor/predictor.py +4 -5
- onnxtr/models/preprocessor/base.py +8 -12
- onnxtr/models/recognition/core.py +0 -1
- onnxtr/models/recognition/models/crnn.py +11 -23
- onnxtr/models/recognition/models/master.py +9 -15
- onnxtr/models/recognition/models/parseq.py +8 -12
- onnxtr/models/recognition/models/sar.py +8 -12
- onnxtr/models/recognition/models/vitstr.py +9 -15
- onnxtr/models/recognition/predictor/_utils.py +6 -9
- onnxtr/models/recognition/predictor/base.py +3 -3
- onnxtr/models/recognition/utils.py +2 -7
- onnxtr/models/recognition/zoo.py +19 -7
- onnxtr/models/zoo.py +7 -9
- onnxtr/transforms/base.py +17 -6
- onnxtr/utils/common_types.py +7 -8
- onnxtr/utils/data.py +7 -11
- onnxtr/utils/fonts.py +1 -6
- onnxtr/utils/geometry.py +18 -49
- onnxtr/utils/multithreading.py +3 -5
- onnxtr/utils/reconstitution.py +6 -8
- onnxtr/utils/repr.py +1 -2
- onnxtr/utils/visualization.py +12 -21
- onnxtr/utils/vocabs.py +1 -2
- onnxtr/version.py +1 -1
- {onnxtr-0.5.1.dist-info → onnxtr-0.6.0.dist-info}/METADATA +70 -41
- onnxtr-0.6.0.dist-info/RECORD +75 -0
- {onnxtr-0.5.1.dist-info → onnxtr-0.6.0.dist-info}/WHEEL +1 -1
- onnxtr-0.5.1.dist-info/RECORD +0 -75
- {onnxtr-0.5.1.dist-info → onnxtr-0.6.0.dist-info}/LICENSE +0 -0
- {onnxtr-0.5.1.dist-info → onnxtr-0.6.0.dist-info}/top_level.txt +0 -0
- {onnxtr-0.5.1.dist-info → onnxtr-0.6.0.dist-info}/zip-safe +0 -0
onnxtr/models/_utils.py
CHANGED
|
@@ -5,7 +5,6 @@
|
|
|
5
5
|
|
|
6
6
|
from math import floor
|
|
7
7
|
from statistics import median_low
|
|
8
|
-
from typing import List, Optional, Tuple
|
|
9
8
|
|
|
10
9
|
import cv2
|
|
11
10
|
import numpy as np
|
|
@@ -20,11 +19,9 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
|
|
|
20
19
|
"""Get the maximum shape ratio of a contour.
|
|
21
20
|
|
|
22
21
|
Args:
|
|
23
|
-
----
|
|
24
22
|
contour: the contour from cv2.findContour
|
|
25
23
|
|
|
26
24
|
Returns:
|
|
27
|
-
-------
|
|
28
25
|
the maximum shape ratio
|
|
29
26
|
"""
|
|
30
27
|
_, (w, h), _ = cv2.minAreaRect(contour)
|
|
@@ -33,7 +30,7 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
|
|
|
33
30
|
|
|
34
31
|
def estimate_orientation(
|
|
35
32
|
img: np.ndarray,
|
|
36
|
-
general_page_orientation:
|
|
33
|
+
general_page_orientation: tuple[int, float] | None = None,
|
|
37
34
|
n_ct: int = 70,
|
|
38
35
|
ratio_threshold_for_lines: float = 3,
|
|
39
36
|
min_confidence: float = 0.2,
|
|
@@ -43,7 +40,6 @@ def estimate_orientation(
|
|
|
43
40
|
lines of the document and the assumption that they should be horizontal.
|
|
44
41
|
|
|
45
42
|
Args:
|
|
46
|
-
----
|
|
47
43
|
img: the img or bitmap to analyze (H, W, C)
|
|
48
44
|
general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence)
|
|
49
45
|
estimated by a model
|
|
@@ -53,7 +49,6 @@ def estimate_orientation(
|
|
|
53
49
|
lower_area: the minimum area of a contour to be considered
|
|
54
50
|
|
|
55
51
|
Returns:
|
|
56
|
-
-------
|
|
57
52
|
the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation)
|
|
58
53
|
"""
|
|
59
54
|
assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
|
|
@@ -64,13 +59,13 @@ def estimate_orientation(
|
|
|
64
59
|
gray_img = cv2.medianBlur(gray_img, 5)
|
|
65
60
|
thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
|
|
66
61
|
else:
|
|
67
|
-
thresh = img.astype(np.uint8)
|
|
62
|
+
thresh = img.astype(np.uint8)
|
|
68
63
|
|
|
69
64
|
page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
|
|
70
65
|
if page_orientation and orientation_confidence >= min_confidence:
|
|
71
66
|
# We rotate the image to the general orientation which improves the detection
|
|
72
67
|
# No expand needed bitmap is already padded
|
|
73
|
-
thresh = rotate_image(thresh, -page_orientation)
|
|
68
|
+
thresh = rotate_image(thresh, -page_orientation)
|
|
74
69
|
else: # That's only required if we do not work on the detection models bin map
|
|
75
70
|
# try to merge words in lines
|
|
76
71
|
(h, w) = img.shape[:2]
|
|
@@ -91,7 +86,7 @@ def estimate_orientation(
|
|
|
91
86
|
|
|
92
87
|
angles = []
|
|
93
88
|
for contour in contours[:n_ct]:
|
|
94
|
-
_, (w, h), angle = cv2.minAreaRect(contour)
|
|
89
|
+
_, (w, h), angle = cv2.minAreaRect(contour)
|
|
95
90
|
if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines
|
|
96
91
|
angles.append(angle)
|
|
97
92
|
elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, subtract 90 degrees
|
|
@@ -119,9 +114,9 @@ def estimate_orientation(
|
|
|
119
114
|
|
|
120
115
|
|
|
121
116
|
def rectify_crops(
|
|
122
|
-
crops:
|
|
123
|
-
orientations:
|
|
124
|
-
) ->
|
|
117
|
+
crops: list[np.ndarray],
|
|
118
|
+
orientations: list[int],
|
|
119
|
+
) -> list[np.ndarray]:
|
|
125
120
|
"""Rotate each crop of the list according to the predicted orientation:
|
|
126
121
|
0: already straight, no rotation
|
|
127
122
|
1: 90 ccw, rotate 3 times ccw
|
|
@@ -139,8 +134,8 @@ def rectify_crops(
|
|
|
139
134
|
|
|
140
135
|
def rectify_loc_preds(
|
|
141
136
|
page_loc_preds: np.ndarray,
|
|
142
|
-
orientations:
|
|
143
|
-
) ->
|
|
137
|
+
orientations: list[int],
|
|
138
|
+
) -> np.ndarray | None:
|
|
144
139
|
"""Orient the quadrangle (Polygon4P) according to the predicted orientation,
|
|
145
140
|
so that the points are in this order: top L, top R, bot R, bot L if the crop is readable
|
|
146
141
|
"""
|
|
@@ -157,16 +152,14 @@ def rectify_loc_preds(
|
|
|
157
152
|
)
|
|
158
153
|
|
|
159
154
|
|
|
160
|
-
def get_language(text: str) ->
|
|
155
|
+
def get_language(text: str) -> tuple[str, float]:
|
|
161
156
|
"""Get languages of a text using langdetect model.
|
|
162
157
|
Get the language with the highest probability or no language if only a few words or a low probability
|
|
163
158
|
|
|
164
159
|
Args:
|
|
165
|
-
----
|
|
166
160
|
text (str): text
|
|
167
161
|
|
|
168
162
|
Returns:
|
|
169
|
-
-------
|
|
170
163
|
The detected language in ISO 639 code and confidence score
|
|
171
164
|
"""
|
|
172
165
|
try:
|
onnxtr/models/builder.py
CHANGED
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from scipy.cluster.hierarchy import fclusterdata
|
|
@@ -20,7 +20,6 @@ class DocumentBuilder(NestedObject):
|
|
|
20
20
|
"""Implements a document builder
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
resolve_lines: whether words should be automatically grouped into lines
|
|
25
24
|
resolve_blocks: whether lines should be automatically grouped into blocks
|
|
26
25
|
paragraph_break: relative length of the minimum space separating paragraphs
|
|
@@ -41,15 +40,13 @@ class DocumentBuilder(NestedObject):
|
|
|
41
40
|
self.export_as_straight_boxes = export_as_straight_boxes
|
|
42
41
|
|
|
43
42
|
@staticmethod
|
|
44
|
-
def _sort_boxes(boxes: np.ndarray) ->
|
|
43
|
+
def _sort_boxes(boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
|
45
44
|
"""Sort bounding boxes from top to bottom, left to right
|
|
46
45
|
|
|
47
46
|
Args:
|
|
48
|
-
----
|
|
49
47
|
boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox)
|
|
50
48
|
|
|
51
49
|
Returns:
|
|
52
|
-
-------
|
|
53
50
|
tuple: indices of ordered boxes of shape (N,), boxes
|
|
54
51
|
If straight boxes are passed to the function, boxes are unchanged
|
|
55
52
|
else: boxes returned are straight boxes fitted to the straightened rotated boxes
|
|
@@ -65,16 +62,14 @@ class DocumentBuilder(NestedObject):
|
|
|
65
62
|
boxes = np.concatenate((boxes.min(1), boxes.max(1)), -1)
|
|
66
63
|
return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort(), boxes
|
|
67
64
|
|
|
68
|
-
def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs:
|
|
65
|
+
def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: list[int]) -> list[list[int]]:
|
|
69
66
|
"""Split a line in sub_lines
|
|
70
67
|
|
|
71
68
|
Args:
|
|
72
|
-
----
|
|
73
69
|
boxes: bounding boxes of shape (N, 4)
|
|
74
70
|
word_idcs: list of indexes for the words of the line
|
|
75
71
|
|
|
76
72
|
Returns:
|
|
77
|
-
-------
|
|
78
73
|
A list of (sub-)lines computed from the original line (words)
|
|
79
74
|
"""
|
|
80
75
|
lines = []
|
|
@@ -105,15 +100,13 @@ class DocumentBuilder(NestedObject):
|
|
|
105
100
|
|
|
106
101
|
return lines
|
|
107
102
|
|
|
108
|
-
def _resolve_lines(self, boxes: np.ndarray) ->
|
|
103
|
+
def _resolve_lines(self, boxes: np.ndarray) -> list[list[int]]:
|
|
109
104
|
"""Order boxes to group them in lines
|
|
110
105
|
|
|
111
106
|
Args:
|
|
112
|
-
----
|
|
113
107
|
boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox
|
|
114
108
|
|
|
115
109
|
Returns:
|
|
116
|
-
-------
|
|
117
110
|
nested list of box indices
|
|
118
111
|
"""
|
|
119
112
|
# Sort boxes, and straighten the boxes if they are rotated
|
|
@@ -153,16 +146,14 @@ class DocumentBuilder(NestedObject):
|
|
|
153
146
|
return lines
|
|
154
147
|
|
|
155
148
|
@staticmethod
|
|
156
|
-
def _resolve_blocks(boxes: np.ndarray, lines:
|
|
149
|
+
def _resolve_blocks(boxes: np.ndarray, lines: list[list[int]]) -> list[list[list[int]]]:
|
|
157
150
|
"""Order lines to group them in blocks
|
|
158
151
|
|
|
159
152
|
Args:
|
|
160
|
-
----
|
|
161
153
|
boxes: bounding boxes of shape (N, 4) or (N, 4, 2)
|
|
162
154
|
lines: list of lines, each line is a list of idx
|
|
163
155
|
|
|
164
156
|
Returns:
|
|
165
|
-
-------
|
|
166
157
|
nested list of box indices
|
|
167
158
|
"""
|
|
168
159
|
# Resolve enclosing boxes of lines
|
|
@@ -207,7 +198,7 @@ class DocumentBuilder(NestedObject):
|
|
|
207
198
|
# Compute clusters
|
|
208
199
|
clusters = fclusterdata(box_features, t=0.1, depth=4, criterion="distance", metric="euclidean")
|
|
209
200
|
|
|
210
|
-
_blocks:
|
|
201
|
+
_blocks: dict[int, list[int]] = {}
|
|
211
202
|
# Form clusters
|
|
212
203
|
for line_idx, cluster_idx in enumerate(clusters):
|
|
213
204
|
if cluster_idx in _blocks.keys():
|
|
@@ -224,13 +215,12 @@ class DocumentBuilder(NestedObject):
|
|
|
224
215
|
self,
|
|
225
216
|
boxes: np.ndarray,
|
|
226
217
|
objectness_scores: np.ndarray,
|
|
227
|
-
word_preds:
|
|
228
|
-
crop_orientations:
|
|
229
|
-
) ->
|
|
218
|
+
word_preds: list[tuple[str, float]],
|
|
219
|
+
crop_orientations: list[dict[str, Any]],
|
|
220
|
+
) -> list[Block]:
|
|
230
221
|
"""Gather independent words in structured blocks
|
|
231
222
|
|
|
232
223
|
Args:
|
|
233
|
-
----
|
|
234
224
|
boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
|
|
235
225
|
objectness_scores: objectness scores of all detected words of the page, of shape N
|
|
236
226
|
word_preds: list of all detected words of the page, of shape N
|
|
@@ -238,7 +228,6 @@ class DocumentBuilder(NestedObject):
|
|
|
238
228
|
the general orientation (orientations + confidences) of the crops
|
|
239
229
|
|
|
240
230
|
Returns:
|
|
241
|
-
-------
|
|
242
231
|
list of block elements
|
|
243
232
|
"""
|
|
244
233
|
if boxes.shape[0] != len(word_preds):
|
|
@@ -295,19 +284,18 @@ class DocumentBuilder(NestedObject):
|
|
|
295
284
|
|
|
296
285
|
def __call__(
|
|
297
286
|
self,
|
|
298
|
-
pages:
|
|
299
|
-
boxes:
|
|
300
|
-
objectness_scores:
|
|
301
|
-
text_preds:
|
|
302
|
-
page_shapes:
|
|
303
|
-
crop_orientations:
|
|
304
|
-
orientations:
|
|
305
|
-
languages:
|
|
287
|
+
pages: list[np.ndarray],
|
|
288
|
+
boxes: list[np.ndarray],
|
|
289
|
+
objectness_scores: list[np.ndarray],
|
|
290
|
+
text_preds: list[list[tuple[str, float]]],
|
|
291
|
+
page_shapes: list[tuple[int, int]],
|
|
292
|
+
crop_orientations: list[dict[str, Any]],
|
|
293
|
+
orientations: list[dict[str, Any]] | None = None,
|
|
294
|
+
languages: list[dict[str, Any]] | None = None,
|
|
306
295
|
) -> Document:
|
|
307
296
|
"""Re-arrange detected words into structured blocks
|
|
308
297
|
|
|
309
298
|
Args:
|
|
310
|
-
----
|
|
311
299
|
pages: list of N elements, where each element represents the page image
|
|
312
300
|
boxes: list of N elements, where each element represents the localization predictions, of shape (*, 4)
|
|
313
301
|
or (*, 4, 2) for all words for a given page
|
|
@@ -322,7 +310,6 @@ class DocumentBuilder(NestedObject):
|
|
|
322
310
|
where each element is a dictionary containing the language (language + confidence)
|
|
323
311
|
|
|
324
312
|
Returns:
|
|
325
|
-
-------
|
|
326
313
|
document object
|
|
327
314
|
"""
|
|
328
315
|
if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
|
onnxtr/models/classification/models/mobilenet.py
CHANGED
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
|
|
7
7
|
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
|
|
@@ -18,7 +18,7 @@ __all__ = [
|
|
|
18
18
|
"mobilenet_v3_small_page_orientation",
|
|
19
19
|
]
|
|
20
20
|
|
|
21
|
-
default_cfgs:
|
|
21
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
22
22
|
"mobilenet_v3_small_crop_orientation": {
|
|
23
23
|
"mean": (0.694, 0.695, 0.693),
|
|
24
24
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -42,7 +42,6 @@ class MobileNetV3(Engine):
|
|
|
42
42
|
"""MobileNetV3 Onnx loader
|
|
43
43
|
|
|
44
44
|
Args:
|
|
45
|
-
----
|
|
46
45
|
model_path: path or url to onnx model file
|
|
47
46
|
engine_cfg: configuration for the inference engine
|
|
48
47
|
cfg: configuration dictionary
|
|
@@ -52,8 +51,8 @@ class MobileNetV3(Engine):
|
|
|
52
51
|
def __init__(
|
|
53
52
|
self,
|
|
54
53
|
model_path: str,
|
|
55
|
-
engine_cfg:
|
|
56
|
-
cfg:
|
|
54
|
+
engine_cfg: EngineConfig | None = None,
|
|
55
|
+
cfg: dict[str, Any] | None = None,
|
|
57
56
|
**kwargs: Any,
|
|
58
57
|
) -> None:
|
|
59
58
|
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
@@ -71,7 +70,7 @@ def _mobilenet_v3(
|
|
|
71
70
|
arch: str,
|
|
72
71
|
model_path: str,
|
|
73
72
|
load_in_8_bit: bool = False,
|
|
74
|
-
engine_cfg:
|
|
73
|
+
engine_cfg: EngineConfig | None = None,
|
|
75
74
|
**kwargs: Any,
|
|
76
75
|
) -> MobileNetV3:
|
|
77
76
|
# Patch the url
|
|
@@ -83,7 +82,7 @@ def _mobilenet_v3(
|
|
|
83
82
|
def mobilenet_v3_small_crop_orientation(
|
|
84
83
|
model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
|
|
85
84
|
load_in_8_bit: bool = False,
|
|
86
|
-
engine_cfg:
|
|
85
|
+
engine_cfg: EngineConfig | None = None,
|
|
87
86
|
**kwargs: Any,
|
|
88
87
|
) -> MobileNetV3:
|
|
89
88
|
"""MobileNetV3-Small architecture as described in
|
|
@@ -97,14 +96,12 @@ def mobilenet_v3_small_crop_orientation(
|
|
|
97
96
|
>>> out = model(input_tensor)
|
|
98
97
|
|
|
99
98
|
Args:
|
|
100
|
-
----
|
|
101
99
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
102
100
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
103
101
|
engine_cfg: configuration for the inference engine
|
|
104
102
|
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
105
103
|
|
|
106
104
|
Returns:
|
|
107
|
-
-------
|
|
108
105
|
MobileNetV3
|
|
109
106
|
"""
|
|
110
107
|
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -113,7 +110,7 @@ def mobilenet_v3_small_crop_orientation(
|
|
|
113
110
|
def mobilenet_v3_small_page_orientation(
|
|
114
111
|
model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
|
|
115
112
|
load_in_8_bit: bool = False,
|
|
116
|
-
engine_cfg:
|
|
113
|
+
engine_cfg: EngineConfig | None = None,
|
|
117
114
|
**kwargs: Any,
|
|
118
115
|
) -> MobileNetV3:
|
|
119
116
|
"""MobileNetV3-Small architecture as described in
|
|
@@ -127,14 +124,12 @@ def mobilenet_v3_small_page_orientation(
|
|
|
127
124
|
>>> out = model(input_tensor)
|
|
128
125
|
|
|
129
126
|
Args:
|
|
130
|
-
----
|
|
131
127
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
132
128
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
133
129
|
engine_cfg: configuration for the inference engine
|
|
134
130
|
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
135
131
|
|
|
136
132
|
Returns:
|
|
137
|
-
-------
|
|
138
133
|
MobileNetV3
|
|
139
134
|
"""
|
|
140
135
|
return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
onnxtr/models/classification/predictor/base.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
from scipy.special import softmax
|
|
@@ -19,26 +19,25 @@ class OrientationPredictor(NestedObject):
|
|
|
19
19
|
4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
|
|
20
20
|
|
|
21
21
|
Args:
|
|
22
|
-
----
|
|
23
22
|
pre_processor: transform inputs for easier batched model inference
|
|
24
23
|
model: core classification architecture (backbone + classification head)
|
|
25
24
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
26
25
|
"""
|
|
27
26
|
|
|
28
|
-
_children_names:
|
|
27
|
+
_children_names: list[str] = ["pre_processor", "model"]
|
|
29
28
|
|
|
30
29
|
def __init__(
|
|
31
30
|
self,
|
|
32
|
-
pre_processor:
|
|
33
|
-
model:
|
|
31
|
+
pre_processor: PreProcessor | None,
|
|
32
|
+
model: Any | None,
|
|
34
33
|
) -> None:
|
|
35
34
|
self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
|
|
36
35
|
self.model = model
|
|
37
36
|
|
|
38
37
|
def __call__(
|
|
39
38
|
self,
|
|
40
|
-
inputs:
|
|
41
|
-
) ->
|
|
39
|
+
inputs: list[np.ndarray],
|
|
40
|
+
) -> list[list[int] | list[float]]:
|
|
42
41
|
# Dimension check
|
|
43
42
|
if any(input.ndim != 3 for input in inputs):
|
|
44
43
|
raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
|
onnxtr/models/classification/zoo.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
from onnxtr.models.engine import EngineConfig
|
|
9
9
|
|
|
@@ -13,14 +13,14 @@ from .predictor import OrientationPredictor
|
|
|
13
13
|
|
|
14
14
|
__all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
|
|
15
15
|
|
|
16
|
-
ORIENTATION_ARCHS:
|
|
16
|
+
ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def _orientation_predictor(
|
|
20
20
|
arch: Any,
|
|
21
21
|
model_type: str,
|
|
22
22
|
load_in_8_bit: bool = False,
|
|
23
|
-
engine_cfg:
|
|
23
|
+
engine_cfg: EngineConfig | None = None,
|
|
24
24
|
disabled: bool = False,
|
|
25
25
|
**kwargs: Any,
|
|
26
26
|
) -> OrientationPredictor:
|
|
@@ -51,8 +51,9 @@ def _orientation_predictor(
|
|
|
51
51
|
|
|
52
52
|
def crop_orientation_predictor(
|
|
53
53
|
arch: Any = "mobilenet_v3_small_crop_orientation",
|
|
54
|
+
batch_size: int = 512,
|
|
54
55
|
load_in_8_bit: bool = False,
|
|
55
|
-
engine_cfg:
|
|
56
|
+
engine_cfg: EngineConfig | None = None,
|
|
56
57
|
**kwargs: Any,
|
|
57
58
|
) -> OrientationPredictor:
|
|
58
59
|
"""Crop orientation classification architecture.
|
|
@@ -64,24 +65,31 @@ def crop_orientation_predictor(
|
|
|
64
65
|
>>> out = model([input_crop])
|
|
65
66
|
|
|
66
67
|
Args:
|
|
67
|
-
----
|
|
68
68
|
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
|
|
69
|
+
batch_size: number of samples the model processes in parallel
|
|
69
70
|
load_in_8_bit: load the 8-bit quantized version of the model
|
|
70
71
|
engine_cfg: configuration of inference engine
|
|
71
72
|
**kwargs: keyword arguments to be passed to the OrientationPredictor
|
|
72
73
|
|
|
73
74
|
Returns:
|
|
74
|
-
-------
|
|
75
75
|
OrientationPredictor
|
|
76
76
|
"""
|
|
77
77
|
model_type = "crop"
|
|
78
|
-
return _orientation_predictor(
|
|
78
|
+
return _orientation_predictor(
|
|
79
|
+
arch=arch,
|
|
80
|
+
batch_size=batch_size,
|
|
81
|
+
model_type=model_type,
|
|
82
|
+
load_in_8_bit=load_in_8_bit,
|
|
83
|
+
engine_cfg=engine_cfg,
|
|
84
|
+
**kwargs,
|
|
85
|
+
)
|
|
79
86
|
|
|
80
87
|
|
|
81
88
|
def page_orientation_predictor(
|
|
82
89
|
arch: Any = "mobilenet_v3_small_page_orientation",
|
|
90
|
+
batch_size: int = 2,
|
|
83
91
|
load_in_8_bit: bool = False,
|
|
84
|
-
engine_cfg:
|
|
92
|
+
engine_cfg: EngineConfig | None = None,
|
|
85
93
|
**kwargs: Any,
|
|
86
94
|
) -> OrientationPredictor:
|
|
87
95
|
"""Page orientation classification architecture.
|
|
@@ -93,15 +101,21 @@ def page_orientation_predictor(
|
|
|
93
101
|
>>> out = model([input_page])
|
|
94
102
|
|
|
95
103
|
Args:
|
|
96
|
-
----
|
|
97
104
|
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
|
|
105
|
+
batch_size: number of samples the model processes in parallel
|
|
98
106
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
99
107
|
engine_cfg: configuration for the inference engine
|
|
100
108
|
**kwargs: keyword arguments to be passed to the OrientationPredictor
|
|
101
109
|
|
|
102
110
|
Returns:
|
|
103
|
-
-------
|
|
104
111
|
OrientationPredictor
|
|
105
112
|
"""
|
|
106
113
|
model_type = "page"
|
|
107
|
-
return _orientation_predictor(
|
|
114
|
+
return _orientation_predictor(
|
|
115
|
+
arch=arch,
|
|
116
|
+
batch_size=batch_size,
|
|
117
|
+
model_type=model_type,
|
|
118
|
+
load_in_8_bit=load_in_8_bit,
|
|
119
|
+
engine_cfg=engine_cfg,
|
|
120
|
+
**kwargs,
|
|
121
|
+
)
|
onnxtr/models/detection/_utils/base.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
|
|
@@ -11,17 +10,15 @@ __all__ = ["_remove_padding"]
|
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
def _remove_padding(
|
|
14
|
-
pages:
|
|
15
|
-
loc_preds:
|
|
13
|
+
pages: list[np.ndarray],
|
|
14
|
+
loc_preds: list[np.ndarray],
|
|
16
15
|
preserve_aspect_ratio: bool,
|
|
17
16
|
symmetric_pad: bool,
|
|
18
17
|
assume_straight_pages: bool,
|
|
19
|
-
) ->
|
|
18
|
+
) -> list[np.ndarray]:
|
|
20
19
|
"""Remove padding from the localization predictions
|
|
21
20
|
|
|
22
21
|
Args:
|
|
23
|
-
----
|
|
24
|
-
|
|
25
22
|
pages: list of pages
|
|
26
23
|
loc_preds: list of localization predictions
|
|
27
24
|
preserve_aspect_ratio: whether the aspect ratio was preserved during padding
|
|
@@ -29,7 +26,6 @@ def _remove_padding(
|
|
|
29
26
|
assume_straight_pages: whether the pages are assumed to be straight
|
|
30
27
|
|
|
31
28
|
Returns:
|
|
32
|
-
-------
|
|
33
29
|
list of unpadded localization predictions
|
|
34
30
|
"""
|
|
35
31
|
if preserve_aspect_ratio:
|
onnxtr/models/detection/core.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List
|
|
7
6
|
|
|
8
7
|
import cv2
|
|
9
8
|
import numpy as np
|
|
@@ -17,7 +16,6 @@ class DetectionPostProcessor(NestedObject):
|
|
|
17
16
|
"""Abstract class to postprocess the raw output of the model
|
|
18
17
|
|
|
19
18
|
Args:
|
|
20
|
-
----
|
|
21
19
|
box_thresh (float): minimal objectness score to consider a box
|
|
22
20
|
bin_thresh (float): threshold to apply to segmentation raw heatmap
|
|
23
21
|
assume_straight_pages (bool): if True, fit straight boxes only
|
|
@@ -37,13 +35,11 @@ class DetectionPostProcessor(NestedObject):
|
|
|
37
35
|
"""Compute the confidence score for a polygon : mean of the p values on the polygon
|
|
38
36
|
|
|
39
37
|
Args:
|
|
40
|
-
----
|
|
41
38
|
pred (np.ndarray): p map returned by the model
|
|
42
39
|
points: coordinates of the polygon
|
|
43
40
|
assume_straight_pages: if True, fit straight boxes only
|
|
44
41
|
|
|
45
42
|
Returns:
|
|
46
|
-
-------
|
|
47
43
|
polygon objectness
|
|
48
44
|
"""
|
|
49
45
|
h, w = pred.shape[:2]
|
|
@@ -71,17 +67,15 @@ class DetectionPostProcessor(NestedObject):
|
|
|
71
67
|
def __call__(
|
|
72
68
|
self,
|
|
73
69
|
proba_map,
|
|
74
|
-
) ->
|
|
70
|
+
) -> list[list[np.ndarray]]:
|
|
75
71
|
"""Performs postprocessing for a list of model outputs
|
|
76
72
|
|
|
77
73
|
Args:
|
|
78
|
-
----
|
|
79
74
|
proba_map: probability map of shape (N, H, W, C)
|
|
80
75
|
|
|
81
76
|
Returns:
|
|
82
|
-
-------
|
|
83
77
|
list of N class predictions (for each input sample), where each class predictions is a list of C tensors
|
|
84
|
-
|
|
78
|
+
of shape (*, 5) or (*, 6)
|
|
85
79
|
"""
|
|
86
80
|
if proba_map.ndim != 4:
|
|
87
81
|
raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.")
|