onnxtr 0.1.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/io/elements.py +17 -4
- onnxtr/io/pdf.py +6 -3
- onnxtr/models/__init__.py +1 -0
- onnxtr/models/_utils.py +57 -20
- onnxtr/models/builder.py +24 -9
- onnxtr/models/classification/models/mobilenet.py +25 -7
- onnxtr/models/classification/predictor/base.py +1 -0
- onnxtr/models/classification/zoo.py +22 -7
- onnxtr/models/detection/_utils/__init__.py +1 -0
- onnxtr/models/detection/_utils/base.py +66 -0
- onnxtr/models/detection/models/differentiable_binarization.py +41 -11
- onnxtr/models/detection/models/fast.py +37 -9
- onnxtr/models/detection/models/linknet.py +39 -9
- onnxtr/models/detection/postprocessor/base.py +4 -3
- onnxtr/models/detection/predictor/base.py +15 -1
- onnxtr/models/detection/zoo.py +16 -3
- onnxtr/models/engine.py +75 -9
- onnxtr/models/predictor/base.py +69 -42
- onnxtr/models/predictor/predictor.py +22 -15
- onnxtr/models/recognition/models/crnn.py +39 -9
- onnxtr/models/recognition/models/master.py +19 -5
- onnxtr/models/recognition/models/parseq.py +20 -5
- onnxtr/models/recognition/models/sar.py +19 -5
- onnxtr/models/recognition/models/vitstr.py +31 -9
- onnxtr/models/recognition/zoo.py +12 -6
- onnxtr/models/zoo.py +22 -0
- onnxtr/py.typed +0 -0
- onnxtr/utils/geometry.py +33 -12
- onnxtr/version.py +1 -1
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/METADATA +81 -16
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/RECORD +35 -32
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/WHEEL +1 -1
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/top_level.txt +0 -1
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/LICENSE +0 -0
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/zip-safe +0 -0
onnxtr/io/elements.py
CHANGED
|
@@ -67,10 +67,11 @@ class Word(Element):
|
|
|
67
67
|
confidence: the confidence associated with the text prediction
|
|
68
68
|
geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
|
|
69
69
|
the page's size
|
|
70
|
+
objectness_score: the objectness score of the detection
|
|
70
71
|
crop_orientation: the general orientation of the crop in degrees and its confidence
|
|
71
72
|
"""
|
|
72
73
|
|
|
73
|
-
_exported_keys: List[str] = ["value", "confidence", "geometry", "crop_orientation"]
|
|
74
|
+
_exported_keys: List[str] = ["value", "confidence", "geometry", "objectness_score", "crop_orientation"]
|
|
74
75
|
_children_names: List[str] = []
|
|
75
76
|
|
|
76
77
|
def __init__(
|
|
@@ -78,12 +79,14 @@ class Word(Element):
|
|
|
78
79
|
value: str,
|
|
79
80
|
confidence: float,
|
|
80
81
|
geometry: Union[BoundingBox, np.ndarray],
|
|
82
|
+
objectness_score: float,
|
|
81
83
|
crop_orientation: Dict[str, Any],
|
|
82
84
|
) -> None:
|
|
83
85
|
super().__init__()
|
|
84
86
|
self.value = value
|
|
85
87
|
self.confidence = confidence
|
|
86
88
|
self.geometry = geometry
|
|
89
|
+
self.objectness_score = objectness_score
|
|
87
90
|
self.crop_orientation = crop_orientation
|
|
88
91
|
|
|
89
92
|
def render(self) -> str:
|
|
@@ -143,7 +146,7 @@ class Line(Element):
|
|
|
143
146
|
all words in it.
|
|
144
147
|
"""
|
|
145
148
|
|
|
146
|
-
_exported_keys: List[str] = ["geometry"]
|
|
149
|
+
_exported_keys: List[str] = ["geometry", "objectness_score"]
|
|
147
150
|
_children_names: List[str] = ["words"]
|
|
148
151
|
words: List[Word] = []
|
|
149
152
|
|
|
@@ -151,7 +154,11 @@ class Line(Element):
|
|
|
151
154
|
self,
|
|
152
155
|
words: List[Word],
|
|
153
156
|
geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
|
|
157
|
+
objectness_score: Optional[float] = None,
|
|
154
158
|
) -> None:
|
|
159
|
+
# Compute the objectness score of the line
|
|
160
|
+
if objectness_score is None:
|
|
161
|
+
objectness_score = float(np.mean([w.objectness_score for w in words]))
|
|
155
162
|
# Resolve the geometry using the smallest enclosing bounding box
|
|
156
163
|
if geometry is None:
|
|
157
164
|
# Check whether this is a rotated or straight box
|
|
@@ -160,6 +167,7 @@ class Line(Element):
|
|
|
160
167
|
|
|
161
168
|
super().__init__(words=words)
|
|
162
169
|
self.geometry = geometry
|
|
170
|
+
self.objectness_score = objectness_score
|
|
163
171
|
|
|
164
172
|
def render(self) -> str:
|
|
165
173
|
"""Renders the full text of the element"""
|
|
@@ -186,7 +194,7 @@ class Block(Element):
|
|
|
186
194
|
all lines and artefacts in it.
|
|
187
195
|
"""
|
|
188
196
|
|
|
189
|
-
_exported_keys: List[str] = ["geometry"]
|
|
197
|
+
_exported_keys: List[str] = ["geometry", "objectness_score"]
|
|
190
198
|
_children_names: List[str] = ["lines", "artefacts"]
|
|
191
199
|
lines: List[Line] = []
|
|
192
200
|
artefacts: List[Artefact] = []
|
|
@@ -196,7 +204,11 @@ class Block(Element):
|
|
|
196
204
|
lines: List[Line] = [],
|
|
197
205
|
artefacts: List[Artefact] = [],
|
|
198
206
|
geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
|
|
207
|
+
objectness_score: Optional[float] = None,
|
|
199
208
|
) -> None:
|
|
209
|
+
# Compute the objectness score of the line
|
|
210
|
+
if objectness_score is None:
|
|
211
|
+
objectness_score = float(np.mean([w.objectness_score for line in lines for w in line.words]))
|
|
200
212
|
# Resolve the geometry using the smallest enclosing bounding box
|
|
201
213
|
if geometry is None:
|
|
202
214
|
line_boxes = [word.geometry for line in lines for word in line.words]
|
|
@@ -208,6 +220,7 @@ class Block(Element):
|
|
|
208
220
|
|
|
209
221
|
super().__init__(lines=lines, artefacts=artefacts)
|
|
210
222
|
self.geometry = geometry
|
|
223
|
+
self.objectness_score = objectness_score
|
|
211
224
|
|
|
212
225
|
def render(self, line_break: str = "\n") -> str:
|
|
213
226
|
"""Renders the full text of the element"""
|
|
@@ -314,7 +327,7 @@ class Page(Element):
|
|
|
314
327
|
SubElement(
|
|
315
328
|
head,
|
|
316
329
|
"meta",
|
|
317
|
-
attrib={"name": "ocr-system", "content": f" {onnxtr.__version__}"}, # type: ignore[attr-defined]
|
|
330
|
+
attrib={"name": "ocr-system", "content": f"onnxtr {onnxtr.__version__}"}, # type: ignore[attr-defined]
|
|
318
331
|
)
|
|
319
332
|
SubElement(
|
|
320
333
|
head,
|
onnxtr/io/pdf.py
CHANGED
|
@@ -15,7 +15,7 @@ __all__ = ["read_pdf"]
|
|
|
15
15
|
|
|
16
16
|
def read_pdf(
|
|
17
17
|
file: AbstractFile,
|
|
18
|
-
scale:
|
|
18
|
+
scale: int = 2,
|
|
19
19
|
rgb_mode: bool = True,
|
|
20
20
|
password: Optional[str] = None,
|
|
21
21
|
**kwargs: Any,
|
|
@@ -38,5 +38,8 @@ def read_pdf(
|
|
|
38
38
|
the list of pages decoded as numpy ndarray of shape H x W x C
|
|
39
39
|
"""
|
|
40
40
|
# Rasterise pages to numpy ndarrays with pypdfium2
|
|
41
|
-
pdf = pdfium.PdfDocument(file, password=password
|
|
42
|
-
|
|
41
|
+
pdf = pdfium.PdfDocument(file, password=password)
|
|
42
|
+
try:
|
|
43
|
+
return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf]
|
|
44
|
+
finally:
|
|
45
|
+
pdf.close()
|
onnxtr/models/__init__.py
CHANGED
onnxtr/models/_utils.py
CHANGED
|
@@ -11,6 +11,8 @@ import cv2
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
from langdetect import LangDetectException, detect_langs
|
|
13
13
|
|
|
14
|
+
from onnxtr.utils.geometry import rotate_image
|
|
15
|
+
|
|
14
16
|
__all__ = ["estimate_orientation", "get_language"]
|
|
15
17
|
|
|
16
18
|
|
|
@@ -29,56 +31,91 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
|
|
|
29
31
|
return max(w / h, h / w)
|
|
30
32
|
|
|
31
33
|
|
|
32
|
-
def estimate_orientation(
|
|
34
|
+
def estimate_orientation(
|
|
35
|
+
img: np.ndarray,
|
|
36
|
+
general_page_orientation: Optional[Tuple[int, float]] = None,
|
|
37
|
+
n_ct: int = 70,
|
|
38
|
+
ratio_threshold_for_lines: float = 3,
|
|
39
|
+
min_confidence: float = 0.2,
|
|
40
|
+
lower_area: int = 100,
|
|
41
|
+
) -> int:
|
|
33
42
|
"""Estimate the angle of the general document orientation based on the
|
|
34
43
|
lines of the document and the assumption that they should be horizontal.
|
|
35
44
|
|
|
36
45
|
Args:
|
|
37
46
|
----
|
|
38
47
|
img: the img or bitmap to analyze (H, W, C)
|
|
48
|
+
general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence)
|
|
49
|
+
estimated by a model
|
|
39
50
|
n_ct: the number of contours used for the orientation estimation
|
|
40
51
|
ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines
|
|
52
|
+
min_confidence: the minimum confidence to consider the general_page_orientation
|
|
53
|
+
lower_area: the minimum area of a contour to be considered
|
|
41
54
|
|
|
42
55
|
Returns:
|
|
43
56
|
-------
|
|
44
|
-
the angle of the
|
|
57
|
+
the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation)
|
|
45
58
|
"""
|
|
46
59
|
assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
if
|
|
50
|
-
thresh = img.astype(np.uint8)
|
|
51
|
-
if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
|
|
60
|
+
thresh = None
|
|
61
|
+
# Convert image to grayscale if necessary
|
|
62
|
+
if img.shape[-1] == 3:
|
|
52
63
|
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
|
53
64
|
gray_img = cv2.medianBlur(gray_img, 5)
|
|
54
|
-
thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
65
|
+
thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
|
|
66
|
+
else:
|
|
67
|
+
thresh = img.astype(np.uint8) # type: ignore[assignment]
|
|
68
|
+
|
|
69
|
+
page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
|
|
70
|
+
if page_orientation and orientation_confidence >= min_confidence:
|
|
71
|
+
# We rotate the image to the general orientation which improves the detection
|
|
72
|
+
# No expand needed bitmap is already padded
|
|
73
|
+
thresh = rotate_image(thresh, -page_orientation) # type: ignore
|
|
74
|
+
else: # That's only required if we do not work on the detection models bin map
|
|
75
|
+
# try to merge words in lines
|
|
76
|
+
(h, w) = img.shape[:2]
|
|
77
|
+
k_x = max(1, (floor(w / 100)))
|
|
78
|
+
k_y = max(1, (floor(h / 100)))
|
|
79
|
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
|
|
80
|
+
thresh = cv2.dilate(thresh, kernel, iterations=1)
|
|
62
81
|
|
|
63
82
|
# extract contours
|
|
64
83
|
contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
|
65
84
|
|
|
66
|
-
# Sort contours
|
|
67
|
-
contours = sorted(
|
|
85
|
+
# Filter & Sort contours
|
|
86
|
+
contours = sorted(
|
|
87
|
+
[contour for contour in contours if cv2.contourArea(contour) > lower_area],
|
|
88
|
+
key=get_max_width_length_ratio,
|
|
89
|
+
reverse=True,
|
|
90
|
+
)
|
|
68
91
|
|
|
69
92
|
angles = []
|
|
70
93
|
for contour in contours[:n_ct]:
|
|
71
|
-
_, (w, h), angle = cv2.minAreaRect(contour)
|
|
94
|
+
_, (w, h), angle = cv2.minAreaRect(contour) # type: ignore[assignment]
|
|
72
95
|
if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines
|
|
73
96
|
angles.append(angle)
|
|
74
97
|
elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, substract 90 degree
|
|
75
98
|
angles.append(angle - 90)
|
|
76
99
|
|
|
77
100
|
if len(angles) == 0:
|
|
78
|
-
|
|
101
|
+
estimated_angle = 0 # in case no angles is found
|
|
79
102
|
else:
|
|
80
103
|
median = -median_low(angles)
|
|
81
|
-
|
|
104
|
+
estimated_angle = -round(median) if abs(median) != 0 else 0
|
|
105
|
+
|
|
106
|
+
# combine with the general orientation and the estimated angle
|
|
107
|
+
if page_orientation and orientation_confidence >= min_confidence:
|
|
108
|
+
# special case where the estimated angle is mostly wrong:
|
|
109
|
+
# case 1: - and + swapped
|
|
110
|
+
# case 2: estimated angle is completely wrong
|
|
111
|
+
# so in this case we prefer the general page orientation
|
|
112
|
+
if abs(estimated_angle) == abs(page_orientation):
|
|
113
|
+
return page_orientation
|
|
114
|
+
estimated_angle = estimated_angle if page_orientation == 0 else page_orientation + estimated_angle
|
|
115
|
+
if estimated_angle > 180:
|
|
116
|
+
estimated_angle -= 360
|
|
117
|
+
|
|
118
|
+
return estimated_angle # return the clockwise angle (negative - left side rotation, positive - right side rotation)
|
|
82
119
|
|
|
83
120
|
|
|
84
121
|
def rectify_crops(
|
onnxtr/models/builder.py
CHANGED
|
@@ -31,7 +31,7 @@ class DocumentBuilder(NestedObject):
|
|
|
31
31
|
def __init__(
|
|
32
32
|
self,
|
|
33
33
|
resolve_lines: bool = True,
|
|
34
|
-
resolve_blocks: bool =
|
|
34
|
+
resolve_blocks: bool = False,
|
|
35
35
|
paragraph_break: float = 0.035,
|
|
36
36
|
export_as_straight_boxes: bool = False,
|
|
37
37
|
) -> None:
|
|
@@ -223,6 +223,7 @@ class DocumentBuilder(NestedObject):
|
|
|
223
223
|
def _build_blocks(
|
|
224
224
|
self,
|
|
225
225
|
boxes: np.ndarray,
|
|
226
|
+
objectness_scores: np.ndarray,
|
|
226
227
|
word_preds: List[Tuple[str, float]],
|
|
227
228
|
crop_orientations: List[Dict[str, Any]],
|
|
228
229
|
) -> List[Block]:
|
|
@@ -230,7 +231,8 @@ class DocumentBuilder(NestedObject):
|
|
|
230
231
|
|
|
231
232
|
Args:
|
|
232
233
|
----
|
|
233
|
-
boxes: bounding boxes of all detected words of the page, of shape (N,
|
|
234
|
+
boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
|
|
235
|
+
objectness_scores: objectness scores of all detected words of the page, of shape N
|
|
234
236
|
word_preds: list of all detected words of the page, of shape N
|
|
235
237
|
crop_orientations: list of dictoinaries containing
|
|
236
238
|
the general orientation (orientations + confidences) of the crops
|
|
@@ -265,12 +267,14 @@ class DocumentBuilder(NestedObject):
|
|
|
265
267
|
Word(
|
|
266
268
|
*word_preds[idx],
|
|
267
269
|
tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type]
|
|
270
|
+
float(objectness_scores[idx]),
|
|
268
271
|
crop_orientations[idx],
|
|
269
272
|
)
|
|
270
273
|
if boxes.ndim == 3
|
|
271
274
|
else Word(
|
|
272
275
|
*word_preds[idx],
|
|
273
276
|
((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
|
|
277
|
+
float(objectness_scores[idx]),
|
|
274
278
|
crop_orientations[idx],
|
|
275
279
|
)
|
|
276
280
|
for idx in line
|
|
@@ -293,6 +297,7 @@ class DocumentBuilder(NestedObject):
|
|
|
293
297
|
self,
|
|
294
298
|
pages: List[np.ndarray],
|
|
295
299
|
boxes: List[np.ndarray],
|
|
300
|
+
objectness_scores: List[np.ndarray],
|
|
296
301
|
text_preds: List[List[Tuple[str, float]]],
|
|
297
302
|
page_shapes: List[Tuple[int, int]],
|
|
298
303
|
crop_orientations: List[Dict[str, Any]],
|
|
@@ -304,8 +309,9 @@ class DocumentBuilder(NestedObject):
|
|
|
304
309
|
Args:
|
|
305
310
|
----
|
|
306
311
|
pages: list of N elements, where each element represents the page image
|
|
307
|
-
boxes: list of N elements, where each element represents the localization predictions, of shape (*,
|
|
308
|
-
or (*,
|
|
312
|
+
boxes: list of N elements, where each element represents the localization predictions, of shape (*, 4)
|
|
313
|
+
or (*, 4, 2) for all words for a given page
|
|
314
|
+
objectness_scores: list of N elements, where each element represents the objectness scores
|
|
309
315
|
text_preds: list of N elements, where each element is the list of all word prediction (text + confidence)
|
|
310
316
|
page_shapes: shape of each page, of size N
|
|
311
317
|
crop_orientations: list of N elements, where each element is
|
|
@@ -319,9 +325,9 @@ class DocumentBuilder(NestedObject):
|
|
|
319
325
|
-------
|
|
320
326
|
document object
|
|
321
327
|
"""
|
|
322
|
-
if len(boxes) != len(text_preds) != len(crop_orientations)
|
|
323
|
-
|
|
324
|
-
):
|
|
328
|
+
if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
|
|
329
|
+
page_shapes
|
|
330
|
+
) != len(crop_orientations) != len(objectness_scores):
|
|
325
331
|
raise ValueError("All arguments are expected to be lists of the same size")
|
|
326
332
|
|
|
327
333
|
_orientations = (
|
|
@@ -339,6 +345,7 @@ class DocumentBuilder(NestedObject):
|
|
|
339
345
|
page,
|
|
340
346
|
self._build_blocks(
|
|
341
347
|
page_boxes,
|
|
348
|
+
loc_scores,
|
|
342
349
|
word_preds,
|
|
343
350
|
word_crop_orientations,
|
|
344
351
|
),
|
|
@@ -347,8 +354,16 @@ class DocumentBuilder(NestedObject):
|
|
|
347
354
|
orientation,
|
|
348
355
|
language,
|
|
349
356
|
)
|
|
350
|
-
for page, _idx, shape, page_boxes, word_preds, word_crop_orientations, orientation, language in zip(
|
|
351
|
-
pages,
|
|
357
|
+
for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language in zip( # noqa: E501
|
|
358
|
+
pages,
|
|
359
|
+
range(len(boxes)),
|
|
360
|
+
page_shapes,
|
|
361
|
+
boxes,
|
|
362
|
+
objectness_scores,
|
|
363
|
+
text_preds,
|
|
364
|
+
crop_orientations,
|
|
365
|
+
_orientations,
|
|
366
|
+
_languages,
|
|
352
367
|
)
|
|
353
368
|
]
|
|
354
369
|
|
|
@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional
|
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
|
|
13
|
-
from ...engine import Engine
|
|
13
|
+
from ...engine import Engine, EngineConfig
|
|
14
14
|
|
|
15
15
|
__all__ = [
|
|
16
16
|
"mobilenet_v3_small_crop_orientation",
|
|
@@ -24,6 +24,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
24
24
|
"input_shape": (3, 256, 256),
|
|
25
25
|
"classes": [0, -90, 180, 90],
|
|
26
26
|
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/mobilenet_v3_small_crop_orientation-5620cf7e.onnx",
|
|
27
|
+
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/mobilenet_v3_small_crop_orientation_static_8_bit-4cfaa621.onnx",
|
|
27
28
|
},
|
|
28
29
|
"mobilenet_v3_small_page_orientation": {
|
|
29
30
|
"mean": (0.694, 0.695, 0.693),
|
|
@@ -31,6 +32,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
31
32
|
"input_shape": (3, 512, 512),
|
|
32
33
|
"classes": [0, -90, 180, 90],
|
|
33
34
|
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/mobilenet_v3_small_page_orientation-d3f76d79.onnx",
|
|
35
|
+
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/mobilenet_v3_small_page_orientation_static_8_bit-3e5ef3dc.onnx",
|
|
34
36
|
},
|
|
35
37
|
}
|
|
36
38
|
|
|
@@ -41,6 +43,7 @@ class MobileNetV3(Engine):
|
|
|
41
43
|
Args:
|
|
42
44
|
----
|
|
43
45
|
model_path: path or url to onnx model file
|
|
46
|
+
engine_cfg: configuration for the inference engine
|
|
44
47
|
cfg: configuration dictionary
|
|
45
48
|
**kwargs: additional arguments to be passed to `Engine`
|
|
46
49
|
"""
|
|
@@ -48,10 +51,11 @@ class MobileNetV3(Engine):
|
|
|
48
51
|
def __init__(
|
|
49
52
|
self,
|
|
50
53
|
model_path: str,
|
|
54
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
51
55
|
cfg: Optional[Dict[str, Any]] = None,
|
|
52
56
|
**kwargs: Any,
|
|
53
57
|
) -> None:
|
|
54
|
-
super().__init__(url=model_path, **kwargs)
|
|
58
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
55
59
|
self.cfg = cfg
|
|
56
60
|
|
|
57
61
|
def __call__(
|
|
@@ -64,14 +68,21 @@ class MobileNetV3(Engine):
|
|
|
64
68
|
def _mobilenet_v3(
|
|
65
69
|
arch: str,
|
|
66
70
|
model_path: str,
|
|
71
|
+
load_in_8_bit: bool = False,
|
|
72
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
67
73
|
**kwargs: Any,
|
|
68
74
|
) -> MobileNetV3:
|
|
75
|
+
# Patch the url
|
|
76
|
+
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
|
|
69
77
|
_cfg = deepcopy(default_cfgs[arch])
|
|
70
|
-
return MobileNetV3(model_path, cfg=_cfg, **kwargs)
|
|
78
|
+
return MobileNetV3(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
|
|
71
79
|
|
|
72
80
|
|
|
73
81
|
def mobilenet_v3_small_crop_orientation(
|
|
74
|
-
model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
|
|
82
|
+
model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
|
|
83
|
+
load_in_8_bit: bool = False,
|
|
84
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
85
|
+
**kwargs: Any,
|
|
75
86
|
) -> MobileNetV3:
|
|
76
87
|
"""MobileNetV3-Small architecture as described in
|
|
77
88
|
`"Searching for MobileNetV3",
|
|
@@ -86,17 +97,22 @@ def mobilenet_v3_small_crop_orientation(
|
|
|
86
97
|
Args:
|
|
87
98
|
----
|
|
88
99
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
100
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
101
|
+
engine_cfg: configuration for the inference engine
|
|
89
102
|
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
90
103
|
|
|
91
104
|
Returns:
|
|
92
105
|
-------
|
|
93
106
|
MobileNetV3
|
|
94
107
|
"""
|
|
95
|
-
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, **kwargs)
|
|
108
|
+
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
96
109
|
|
|
97
110
|
|
|
98
111
|
def mobilenet_v3_small_page_orientation(
|
|
99
|
-
model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
|
|
112
|
+
model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
|
|
113
|
+
load_in_8_bit: bool = False,
|
|
114
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
115
|
+
**kwargs: Any,
|
|
100
116
|
) -> MobileNetV3:
|
|
101
117
|
"""MobileNetV3-Small architecture as described in
|
|
102
118
|
`"Searching for MobileNetV3",
|
|
@@ -111,10 +127,12 @@ def mobilenet_v3_small_page_orientation(
|
|
|
111
127
|
Args:
|
|
112
128
|
----
|
|
113
129
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
130
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
131
|
+
engine_cfg: configuration for the inference engine
|
|
114
132
|
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
115
133
|
|
|
116
134
|
Returns:
|
|
117
135
|
-------
|
|
118
136
|
MobileNetV3
|
|
119
137
|
"""
|
|
120
|
-
return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, **kwargs)
|
|
138
|
+
return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -22,6 +22,7 @@ class OrientationPredictor(NestedObject):
|
|
|
22
22
|
----
|
|
23
23
|
pre_processor: transform inputs for easier batched model inference
|
|
24
24
|
model: core classification architecture (backbone + classification head)
|
|
25
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
25
26
|
"""
|
|
26
27
|
|
|
27
28
|
_children_names: List[str] = ["pre_processor", "model"]
|
|
@@ -5,6 +5,8 @@
|
|
|
5
5
|
|
|
6
6
|
from typing import Any, List
|
|
7
7
|
|
|
8
|
+
from onnxtr.models.engine import EngineConfig
|
|
9
|
+
|
|
8
10
|
from .. import classification
|
|
9
11
|
from ..preprocessor import PreProcessor
|
|
10
12
|
from .predictor import OrientationPredictor
|
|
@@ -14,24 +16,30 @@ __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
|
|
|
14
16
|
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
|
|
15
17
|
|
|
16
18
|
|
|
17
|
-
def _orientation_predictor(
|
|
19
|
+
def _orientation_predictor(
|
|
20
|
+
arch: str, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any
|
|
21
|
+
) -> OrientationPredictor:
|
|
18
22
|
if arch not in ORIENTATION_ARCHS:
|
|
19
23
|
raise ValueError(f"unknown architecture '{arch}'")
|
|
20
24
|
|
|
21
25
|
# Load directly classifier from backbone
|
|
22
|
-
_model = classification.__dict__[arch]()
|
|
26
|
+
_model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg)
|
|
23
27
|
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
|
|
24
28
|
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
|
|
25
29
|
kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
|
|
26
30
|
input_shape = _model.cfg["input_shape"][1:]
|
|
27
31
|
predictor = OrientationPredictor(
|
|
28
|
-
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs),
|
|
32
|
+
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs),
|
|
33
|
+
_model,
|
|
29
34
|
)
|
|
30
35
|
return predictor
|
|
31
36
|
|
|
32
37
|
|
|
33
38
|
def crop_orientation_predictor(
|
|
34
|
-
arch: Any = "mobilenet_v3_small_crop_orientation",
|
|
39
|
+
arch: Any = "mobilenet_v3_small_crop_orientation",
|
|
40
|
+
load_in_8_bit: bool = False,
|
|
41
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
42
|
+
**kwargs: Any,
|
|
35
43
|
) -> OrientationPredictor:
|
|
36
44
|
"""Crop orientation classification architecture.
|
|
37
45
|
|
|
@@ -44,17 +52,22 @@ def crop_orientation_predictor(
|
|
|
44
52
|
Args:
|
|
45
53
|
----
|
|
46
54
|
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
|
|
55
|
+
load_in_8_bit: load the 8-bit quantized version of the model
|
|
56
|
+
engine_cfg: configuration of inference engine
|
|
47
57
|
**kwargs: keyword arguments to be passed to the OrientationPredictor
|
|
48
58
|
|
|
49
59
|
Returns:
|
|
50
60
|
-------
|
|
51
61
|
OrientationPredictor
|
|
52
62
|
"""
|
|
53
|
-
return _orientation_predictor(arch, **kwargs)
|
|
63
|
+
return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
|
|
54
64
|
|
|
55
65
|
|
|
56
66
|
def page_orientation_predictor(
|
|
57
|
-
arch: Any = "mobilenet_v3_small_page_orientation",
|
|
67
|
+
arch: Any = "mobilenet_v3_small_page_orientation",
|
|
68
|
+
load_in_8_bit: bool = False,
|
|
69
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
70
|
+
**kwargs: Any,
|
|
58
71
|
) -> OrientationPredictor:
|
|
59
72
|
"""Page orientation classification architecture.
|
|
60
73
|
|
|
@@ -67,10 +80,12 @@ def page_orientation_predictor(
|
|
|
67
80
|
Args:
|
|
68
81
|
----
|
|
69
82
|
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
|
|
83
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
84
|
+
engine_cfg: configuration for the inference engine
|
|
70
85
|
**kwargs: keyword arguments to be passed to the OrientationPredictor
|
|
71
86
|
|
|
72
87
|
Returns:
|
|
73
88
|
-------
|
|
74
89
|
OrientationPredictor
|
|
75
90
|
"""
|
|
76
|
-
return _orientation_predictor(arch, **kwargs)
|
|
91
|
+
return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from . base import *
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
__all__ = ["_remove_padding"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _remove_padding(
|
|
14
|
+
pages: List[np.ndarray],
|
|
15
|
+
loc_preds: List[np.ndarray],
|
|
16
|
+
preserve_aspect_ratio: bool,
|
|
17
|
+
symmetric_pad: bool,
|
|
18
|
+
assume_straight_pages: bool,
|
|
19
|
+
) -> List[np.ndarray]:
|
|
20
|
+
"""Remove padding from the localization predictions
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
----
|
|
24
|
+
|
|
25
|
+
pages: list of pages
|
|
26
|
+
loc_preds: list of localization predictions
|
|
27
|
+
preserve_aspect_ratio: whether the aspect ratio was preserved during padding
|
|
28
|
+
symmetric_pad: whether the padding was symmetric
|
|
29
|
+
assume_straight_pages: whether the pages are assumed to be straight
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
-------
|
|
33
|
+
list of unpaded localization predictions
|
|
34
|
+
"""
|
|
35
|
+
if preserve_aspect_ratio:
|
|
36
|
+
# Rectify loc_preds to remove padding
|
|
37
|
+
rectified_preds = []
|
|
38
|
+
for page, loc_pred in zip(pages, loc_preds):
|
|
39
|
+
h, w = page.shape[0], page.shape[1]
|
|
40
|
+
if h > w:
|
|
41
|
+
# y unchanged, dilate x coord
|
|
42
|
+
if symmetric_pad:
|
|
43
|
+
if assume_straight_pages:
|
|
44
|
+
loc_pred[:, [0, 2]] = (loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5
|
|
45
|
+
else:
|
|
46
|
+
loc_pred[:, :, 0] = (loc_pred[:, :, 0] - 0.5) * h / w + 0.5
|
|
47
|
+
else:
|
|
48
|
+
if assume_straight_pages:
|
|
49
|
+
loc_pred[:, [0, 2]] *= h / w
|
|
50
|
+
else:
|
|
51
|
+
loc_pred[:, :, 0] *= h / w
|
|
52
|
+
elif w > h:
|
|
53
|
+
# x unchanged, dilate y coord
|
|
54
|
+
if symmetric_pad:
|
|
55
|
+
if assume_straight_pages:
|
|
56
|
+
loc_pred[:, [1, 3]] = (loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5
|
|
57
|
+
else:
|
|
58
|
+
loc_pred[:, :, 1] = (loc_pred[:, :, 1] - 0.5) * w / h + 0.5
|
|
59
|
+
else:
|
|
60
|
+
if assume_straight_pages:
|
|
61
|
+
loc_pred[:, [1, 3]] *= w / h
|
|
62
|
+
else:
|
|
63
|
+
loc_pred[:, :, 1] *= w / h
|
|
64
|
+
rectified_preds.append(np.clip(loc_pred, 0, 1))
|
|
65
|
+
return rectified_preds
|
|
66
|
+
return loc_preds
|