onnxtr-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. onnxtr/__init__.py +2 -0
  2. onnxtr/contrib/__init__.py +0 -0
  3. onnxtr/contrib/artefacts.py +131 -0
  4. onnxtr/contrib/base.py +105 -0
  5. onnxtr/file_utils.py +33 -0
  6. onnxtr/io/__init__.py +5 -0
  7. onnxtr/io/elements.py +455 -0
  8. onnxtr/io/html.py +28 -0
  9. onnxtr/io/image.py +56 -0
  10. onnxtr/io/pdf.py +42 -0
  11. onnxtr/io/reader.py +85 -0
  12. onnxtr/models/__init__.py +4 -0
  13. onnxtr/models/_utils.py +141 -0
  14. onnxtr/models/builder.py +355 -0
  15. onnxtr/models/classification/__init__.py +2 -0
  16. onnxtr/models/classification/models/__init__.py +1 -0
  17. onnxtr/models/classification/models/mobilenet.py +120 -0
  18. onnxtr/models/classification/predictor/__init__.py +1 -0
  19. onnxtr/models/classification/predictor/base.py +57 -0
  20. onnxtr/models/classification/zoo.py +76 -0
  21. onnxtr/models/detection/__init__.py +2 -0
  22. onnxtr/models/detection/core.py +101 -0
  23. onnxtr/models/detection/models/__init__.py +3 -0
  24. onnxtr/models/detection/models/differentiable_binarization.py +159 -0
  25. onnxtr/models/detection/models/fast.py +160 -0
  26. onnxtr/models/detection/models/linknet.py +160 -0
  27. onnxtr/models/detection/postprocessor/__init__.py +0 -0
  28. onnxtr/models/detection/postprocessor/base.py +144 -0
  29. onnxtr/models/detection/predictor/__init__.py +1 -0
  30. onnxtr/models/detection/predictor/base.py +54 -0
  31. onnxtr/models/detection/zoo.py +73 -0
  32. onnxtr/models/engine.py +50 -0
  33. onnxtr/models/predictor/__init__.py +1 -0
  34. onnxtr/models/predictor/base.py +175 -0
  35. onnxtr/models/predictor/predictor.py +145 -0
  36. onnxtr/models/preprocessor/__init__.py +1 -0
  37. onnxtr/models/preprocessor/base.py +118 -0
  38. onnxtr/models/recognition/__init__.py +2 -0
  39. onnxtr/models/recognition/core.py +28 -0
  40. onnxtr/models/recognition/models/__init__.py +5 -0
  41. onnxtr/models/recognition/models/crnn.py +226 -0
  42. onnxtr/models/recognition/models/master.py +145 -0
  43. onnxtr/models/recognition/models/parseq.py +134 -0
  44. onnxtr/models/recognition/models/sar.py +134 -0
  45. onnxtr/models/recognition/models/vitstr.py +166 -0
  46. onnxtr/models/recognition/predictor/__init__.py +1 -0
  47. onnxtr/models/recognition/predictor/_utils.py +86 -0
  48. onnxtr/models/recognition/predictor/base.py +79 -0
  49. onnxtr/models/recognition/utils.py +89 -0
  50. onnxtr/models/recognition/zoo.py +69 -0
  51. onnxtr/models/zoo.py +114 -0
  52. onnxtr/transforms/__init__.py +1 -0
  53. onnxtr/transforms/base.py +112 -0
  54. onnxtr/utils/__init__.py +4 -0
  55. onnxtr/utils/common_types.py +18 -0
  56. onnxtr/utils/data.py +126 -0
  57. onnxtr/utils/fonts.py +41 -0
  58. onnxtr/utils/geometry.py +498 -0
  59. onnxtr/utils/multithreading.py +50 -0
  60. onnxtr/utils/reconstitution.py +70 -0
  61. onnxtr/utils/repr.py +64 -0
  62. onnxtr/utils/visualization.py +291 -0
  63. onnxtr/utils/vocabs.py +71 -0
  64. onnxtr/version.py +1 -0
  65. onnxtr-0.1.0.dist-info/LICENSE +201 -0
  66. onnxtr-0.1.0.dist-info/METADATA +481 -0
  67. onnxtr-0.1.0.dist-info/RECORD +70 -0
  68. onnxtr-0.1.0.dist-info/WHEEL +5 -0
  69. onnxtr-0.1.0.dist-info/top_level.txt +2 -0
  70. onnxtr-0.1.0.dist-info/zip-safe +1 -0
onnxtr/models/_utils.py
@@ -0,0 +1,141 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from math import floor
+ from statistics import median_low
+ from typing import List, Optional, Tuple
+
+ import cv2
+ import numpy as np
+ from langdetect import LangDetectException, detect_langs
+
+ __all__ = ["estimate_orientation", "get_language"]
+
+
+ def get_max_width_length_ratio(contour: np.ndarray) -> float:
+     """Get the maximum shape ratio of a contour.
+
+     Args:
+     ----
+         contour: the contour from cv2.findContour
+
+     Returns:
+     -------
+         the maximum shape ratio
+     """
+     _, (w, h), _ = cv2.minAreaRect(contour)
+     return max(w / h, h / w)
+
+
+ def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> int:
+     """Estimate the angle of the general document orientation based on the
+     lines of the document and the assumption that they should be horizontal.
+
+     Args:
+     ----
+         img: the image or bitmap to analyze (H, W, C)
+         n_ct: the number of contours used for the orientation estimation
+         ratio_threshold_for_lines: the w/h ratio used to discriminate lines
+
+     Returns:
+     -------
+         the angle of the general document orientation
+     """
+     assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
+     max_value = np.max(img)
+     min_value = np.min(img)
+     if max_value <= 1 and min_value >= 0 or (max_value <= 255 and min_value >= 0 and img.shape[-1] == 1):
+         thresh = img.astype(np.uint8)
+     if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
+         gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+         gray_img = cv2.medianBlur(gray_img, 5)
+         thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]  # type: ignore[assignment]
+
+     # try to merge words in lines
+     (h, w) = img.shape[:2]
+     k_x = max(1, (floor(w / 100)))
+     k_y = max(1, (floor(h / 100)))
+     kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
+     thresh = cv2.dilate(thresh, kernel, iterations=1)  # type: ignore[assignment]
+
+     # extract contours
+     contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+
+     # Sort contours
+     contours = sorted(contours, key=get_max_width_length_ratio, reverse=True)
+
+     angles = []
+     for contour in contours[:n_ct]:
+         _, (w, h), angle = cv2.minAreaRect(contour)
+         if w / h > ratio_threshold_for_lines:  # select only contours with ratio like lines
+             angles.append(angle)
+         elif w / h < 1 / ratio_threshold_for_lines:  # if lines are vertical, subtract 90 degrees
+             angles.append(angle - 90)
+
+     if len(angles) == 0:
+         return 0  # in case no angle is found
+     else:
+         median = -median_low(angles)
+         return round(median) if abs(median) != 0 else 0
+
+
+ def rectify_crops(
+     crops: List[np.ndarray],
+     orientations: List[int],
+ ) -> List[np.ndarray]:
+     """Rotate each crop of the list according to the predicted orientation:
+     0: already straight, no rotation
+     1: 90 ccw, rotate 3 times ccw
+     2: 180, rotate 2 times ccw
+     3: 270 ccw, rotate 1 time ccw
+     """
+     # Inverse predictions (if an angle of +90 is detected, rotate by -90)
+     orientations = [4 - pred if pred != 0 else 0 for pred in orientations]
+     return (
+         [crop if orientation == 0 else np.rot90(crop, orientation) for orientation, crop in zip(orientations, crops)]
+         if len(orientations) > 0
+         else []
+     )
+
+
+ def rectify_loc_preds(
+     page_loc_preds: np.ndarray,
+     orientations: List[int],
+ ) -> Optional[np.ndarray]:
+     """Orient the quadrangle (Polygon4P) according to the predicted orientation,
+     so that the points are in this order: top L, top R, bot R, bot L if the crop is readable
+     """
+     return (
+         np.stack(
+             [
+                 np.roll(page_loc_pred, orientation, axis=0)
+                 for orientation, page_loc_pred in zip(orientations, page_loc_preds)
+             ],
+             axis=0,
+         )
+         if len(orientations) > 0
+         else None
+     )
+
+
+ def get_language(text: str) -> Tuple[str, float]:
+     """Get the language of a text using the langdetect model.
+     Returns the language with the highest probability, or no language if the text is too short or the probability too low.
+
+     Args:
+     ----
+         text (str): text
+
+     Returns:
+     -------
+         The detected language in ISO 639 code and confidence score
+     """
+     try:
+         lang = detect_langs(text.lower())[0]
+     except LangDetectException:
+         return "unknown", 0.0
+     if len(text) <= 1 or (len(text) <= 5 and lang.prob <= 0.2):
+         return "unknown", 0.0
+     return lang.lang, lang.prob
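For context, a minimal usage sketch of the two public helpers in this file, `estimate_orientation` and `get_language`; the image path and sample sentence below are placeholders, and the module path follows the file list above (onnxtr/models/_utils.py):

```python
import cv2

from onnxtr.models._utils import estimate_orientation, get_language

# "scan.jpg" is a placeholder path; estimate_orientation expects an (H, W, C) uint8 image
img = cv2.imread("scan.jpg")
angle = estimate_orientation(img)  # estimated page skew in degrees, 0 if no line-like contour is found

lang, prob = get_language("This document is written in English.")
print(angle, lang, prob)  # e.g. -2 'en' 0.99
```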
onnxtr/models/builder.py
@@ -0,0 +1,355 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ from scipy.cluster.hierarchy import fclusterdata
+
+ from onnxtr.io.elements import Block, Document, Line, Page, Word
+ from onnxtr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes
+ from onnxtr.utils.repr import NestedObject
+
+ __all__ = ["DocumentBuilder"]
+
+
+ class DocumentBuilder(NestedObject):
+     """Implements a document builder
+
+     Args:
+     ----
+         resolve_lines: whether words should be automatically grouped into lines
+         resolve_blocks: whether lines should be automatically grouped into blocks
+         paragraph_break: relative length of the minimum space separating paragraphs
+         export_as_straight_boxes: if True, force straight boxes in the export (fit a rectangle
+             box to all rotated boxes). Else, keep the boxes format unchanged, no matter what it is.
+     """
+
+     def __init__(
+         self,
+         resolve_lines: bool = True,
+         resolve_blocks: bool = True,
+         paragraph_break: float = 0.035,
+         export_as_straight_boxes: bool = False,
+     ) -> None:
+         self.resolve_lines = resolve_lines
+         self.resolve_blocks = resolve_blocks
+         self.paragraph_break = paragraph_break
+         self.export_as_straight_boxes = export_as_straight_boxes
+
+     @staticmethod
+     def _sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         """Sort bounding boxes from top to bottom, left to right
+
+         Args:
+         ----
+             boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox)
+
+         Returns:
+         -------
+             tuple: indices of ordered boxes of shape (N,), boxes
+                 If straight boxes are passed to the function, the boxes are unchanged;
+                 otherwise, the returned boxes are straight boxes fitted to the straightened rotated boxes,
+                 so that the lines can afterwards be fitted to the straightened page
+         """
+         if boxes.ndim == 3:
+             boxes = rotate_boxes(
+                 loc_preds=boxes,
+                 angle=-estimate_page_angle(boxes),
+                 orig_shape=(1024, 1024),
+                 min_angle=5.0,
+             )
+             boxes = np.concatenate((boxes.min(1), boxes.max(1)), -1)
+         return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort(), boxes
+
+     def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: List[int]) -> List[List[int]]:
+         """Split a line in sub_lines
+
+         Args:
+         ----
+             boxes: bounding boxes of shape (N, 4)
+             word_idcs: list of indexes for the words of the line
+
+         Returns:
+         -------
+             A list of (sub-)lines computed from the original line (words)
+         """
+         lines = []
+         # Sort words horizontally
+         word_idcs = [word_idcs[idx] for idx in boxes[word_idcs, 0].argsort().tolist()]
+
+         # Split the line horizontally if needed
+         if len(word_idcs) < 2:
+             lines.append(word_idcs)
+         else:
+             sub_line = [word_idcs[0]]
+             for i in word_idcs[1:]:
+                 horiz_break = True
+
+                 prev_box = boxes[sub_line[-1]]
+                 # Compute distance between boxes
+                 dist = boxes[i, 0] - prev_box[2]
+                 # If distance between boxes is lower than paragraph break, same sub-line
+                 if dist < self.paragraph_break:
+                     horiz_break = False
+
+                 if horiz_break:
+                     lines.append(sub_line)
+                     sub_line = []
+
+                 sub_line.append(i)
+             lines.append(sub_line)
+
+         return lines
+
+     def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]:
+         """Order boxes to group them in lines
+
+         Args:
+         ----
+             boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox
+
+         Returns:
+         -------
+             nested list of box indices
+         """
+         # Sort boxes, and straighten the boxes if they are rotated
+         idxs, boxes = self._sort_boxes(boxes)
+
+         # Compute median for boxes heights
+         y_med = np.median(boxes[:, 3] - boxes[:, 1])
+
+         lines = []
+         words = [idxs[0]]  # Assign the top-left word to the first line
+         # Define a mean y-center for the line
+         y_center_sum = boxes[idxs[0]][[1, 3]].mean()
+
+         for idx in idxs[1:]:
+             vert_break = True
+
+             # Compute y_dist
+             y_dist = abs(boxes[idx][[1, 3]].mean() - y_center_sum / len(words))
+             # If y-center of the box is close enough to mean y-center of the line, same line
+             if y_dist < y_med / 2:
+                 vert_break = False
+
+             if vert_break:
+                 # Compute sub-lines (horizontal split)
+                 lines.extend(self._resolve_sub_lines(boxes, words))
+                 words = []
+                 y_center_sum = 0
+
+             words.append(idx)
+             y_center_sum += boxes[idx][[1, 3]].mean()
+
+         # Use the remaining words to form the last line(s)
+         if len(words) > 0:
+             # Compute sub-lines (horizontal split)
+             lines.extend(self._resolve_sub_lines(boxes, words))
+
+         return lines
+
+     @staticmethod
+     def _resolve_blocks(boxes: np.ndarray, lines: List[List[int]]) -> List[List[List[int]]]:
+         """Order lines to group them in blocks
+
+         Args:
+         ----
+             boxes: bounding boxes of shape (N, 4) or (N, 4, 2)
+             lines: list of lines, each line is a list of idx
+
+         Returns:
+         -------
+             nested list of box indices
+         """
+         # Resolve enclosing boxes of lines
+         if boxes.ndim == 3:
+             box_lines: np.ndarray = np.asarray([
+                 resolve_enclosing_rbbox([tuple(boxes[idx, :, :]) for idx in line])  # type: ignore[misc]
+                 for line in lines
+             ])
+         else:
+             _box_lines = [
+                 resolve_enclosing_bbox([(tuple(boxes[idx, :2]), tuple(boxes[idx, 2:])) for idx in line])
+                 for line in lines
+             ]
+             box_lines = np.asarray([(x1, y1, x2, y2) for ((x1, y1), (x2, y2)) in _box_lines])
+
+         # Compute geometrical features of lines to cluster
+         # Clustering only on box centers yields poor results for complex documents
+         if boxes.ndim == 3:
+             box_features: np.ndarray = np.stack(
+                 (
+                     (box_lines[:, 0, 0] + box_lines[:, 0, 1]) / 2,
+                     (box_lines[:, 0, 0] + box_lines[:, 2, 0]) / 2,
+                     (box_lines[:, 0, 0] + box_lines[:, 2, 1]) / 2,
+                     (box_lines[:, 0, 1] + box_lines[:, 2, 1]) / 2,
+                     (box_lines[:, 0, 1] + box_lines[:, 2, 0]) / 2,
+                     (box_lines[:, 2, 0] + box_lines[:, 2, 1]) / 2,
+                 ),
+                 axis=-1,
+             )
+         else:
+             box_features = np.stack(
+                 (
+                     (box_lines[:, 0] + box_lines[:, 3]) / 2,
+                     (box_lines[:, 1] + box_lines[:, 2]) / 2,
+                     (box_lines[:, 0] + box_lines[:, 2]) / 2,
+                     (box_lines[:, 1] + box_lines[:, 3]) / 2,
+                     box_lines[:, 0],
+                     box_lines[:, 1],
+                 ),
+                 axis=-1,
+             )
+         # Compute clusters
+         clusters = fclusterdata(box_features, t=0.1, depth=4, criterion="distance", metric="euclidean")
+
+         _blocks: Dict[int, List[int]] = {}
+         # Form clusters
+         for line_idx, cluster_idx in enumerate(clusters):
+             if cluster_idx in _blocks.keys():
+                 _blocks[cluster_idx].append(line_idx)
+             else:
+                 _blocks[cluster_idx] = [line_idx]
+
+         # Retrieve word-box level to return a fully nested structure
+         blocks = [[lines[idx] for idx in block] for block in _blocks.values()]
+
+         return blocks
+
+     def _build_blocks(
+         self,
+         boxes: np.ndarray,
+         word_preds: List[Tuple[str, float]],
+         crop_orientations: List[Dict[str, Any]],
+     ) -> List[Block]:
+         """Gather independent words in structured blocks
+
+         Args:
+         ----
+             boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
+             word_preds: list of all detected words of the page, of shape N
+             crop_orientations: list of dictionaries containing
+                 the general orientation (orientations + confidences) of the crops
+
+         Returns:
+         -------
+             list of block elements
+         """
+         if boxes.shape[0] != len(word_preds):
+             raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}")
+
+         if boxes.shape[0] == 0:
+             return []
+
+         # Decide whether we try to form lines
+         _boxes = boxes
+         if self.resolve_lines:
+             lines = self._resolve_lines(_boxes if _boxes.ndim == 3 else _boxes[:, :4])
+             # Decide whether we try to form blocks
+             if self.resolve_blocks and len(lines) > 1:
+                 _blocks = self._resolve_blocks(_boxes if _boxes.ndim == 3 else _boxes[:, :4], lines)
+             else:
+                 _blocks = [lines]
+         else:
+             # Sort bounding boxes, one line for all boxes, one block for the line
+             lines = [self._sort_boxes(_boxes if _boxes.ndim == 3 else _boxes[:, :4])[0]]  # type: ignore[list-item]
+             _blocks = [lines]
+
+         blocks = [
+             Block([
+                 Line([
+                     Word(
+                         *word_preds[idx],
+                         tuple([tuple(pt) for pt in boxes[idx].tolist()]),  # type: ignore[arg-type]
+                         crop_orientations[idx],
+                     )
+                     if boxes.ndim == 3
+                     else Word(
+                         *word_preds[idx],
+                         ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
+                         crop_orientations[idx],
+                     )
+                     for idx in line
+                 ])
+                 for line in lines
+             ])
+             for lines in _blocks
+         ]
+
+         return blocks
+
+     def extra_repr(self) -> str:
+         return (
+             f"resolve_lines={self.resolve_lines}, resolve_blocks={self.resolve_blocks}, "
+             f"paragraph_break={self.paragraph_break}, "
+             f"export_as_straight_boxes={self.export_as_straight_boxes}"
+         )
+
+     def __call__(
+         self,
+         pages: List[np.ndarray],
+         boxes: List[np.ndarray],
+         text_preds: List[List[Tuple[str, float]]],
+         page_shapes: List[Tuple[int, int]],
+         crop_orientations: List[Dict[str, Any]],
+         orientations: Optional[List[Dict[str, Any]]] = None,
+         languages: Optional[List[Dict[str, Any]]] = None,
+     ) -> Document:
+         """Re-arrange detected words into structured blocks
+
+         Args:
+         ----
+             pages: list of N elements, where each element represents the page image
+             boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5)
+                 or (*, 6) for all words for a given page
+             text_preds: list of N elements, where each element is the list of all word predictions (text + confidence)
+             page_shapes: shape of each page, of size N
+             crop_orientations: list of N elements, where each element is
+                 a dictionary containing the general orientation (orientations + confidences) of the crops
+             orientations: optional, list of N elements,
+                 where each element is a dictionary containing the orientation (orientation + confidence)
+             languages: optional, list of N elements,
+                 where each element is a dictionary containing the language (language + confidence)
+
+         Returns:
+         -------
+             document object
+         """
+         if len(boxes) != len(text_preds) != len(crop_orientations) or len(boxes) != len(page_shapes) != len(
+             crop_orientations
+         ):
+             raise ValueError("All arguments are expected to be lists of the same size")
+
+         _orientations = (
+             orientations if isinstance(orientations, list) else [None] * len(boxes)  # type: ignore[list-item]
+         )
+         _languages = languages if isinstance(languages, list) else [None] * len(boxes)  # type: ignore[list-item]
+         if self.export_as_straight_boxes and len(boxes) > 0:
+             # If boxes are already straight OK, else fit a bounding rect
+             if boxes[0].ndim == 3:
+                 # Iterate over pages and boxes
+                 boxes = [np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1) for p_boxes in boxes]
+
+         _pages = [
+             Page(
+                 page,
+                 self._build_blocks(
+                     page_boxes,
+                     word_preds,
+                     word_crop_orientations,
+                 ),
+                 _idx,
+                 shape,
+                 orientation,
+                 language,
+             )
+             for page, _idx, shape, page_boxes, word_preds, word_crop_orientations, orientation, language in zip(
+                 pages, range(len(boxes)), page_shapes, boxes, text_preds, crop_orientations, _orientations, _languages
+             )
+         ]
+
+         return Document(_pages)
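To illustrate how the builder above is driven (the OCR predictor normally does this internally), here is a minimal sketch with one synthetic page and two straight word boxes in relative coordinates. The per-word crop-orientation payload ({"value": ..., "confidence": ...}) and the Document.render() call are assumptions about the onnxtr.io.elements API, which is not shown in this diff:

```python
import numpy as np

from onnxtr.models.builder import DocumentBuilder

builder = DocumentBuilder(resolve_lines=True, resolve_blocks=True, paragraph_break=0.035)

# one synthetic 1024x1024 page with two words on the same line
page = np.zeros((1024, 1024, 3), dtype=np.uint8)
# straight boxes: (xmin, ymin, xmax, ymax, confidence) in relative coordinates
boxes = [np.array([[0.10, 0.10, 0.30, 0.15, 0.99], [0.32, 0.10, 0.55, 0.15, 0.98]])]
text_preds = [[("Hello", 0.99), ("world", 0.98)]]
# assumed per-word crop-orientation dicts, one list per page
crop_orientations = [[{"value": 0, "confidence": 1.0}, {"value": 0, "confidence": 1.0}]]

doc = builder([page], boxes, text_preds, [(1024, 1024)], crop_orientations)
print(doc.render())  # both words end up in a single line of a single block
```

The horizontal gap between the two boxes (0.02) is below paragraph_break (0.035), so _resolve_sub_lines keeps them in the same sub-line.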
onnxtr/models/classification/__init__.py
@@ -0,0 +1,2 @@
+ from .models import *
+ from .zoo import *
onnxtr/models/classification/models/__init__.py
@@ -0,0 +1 @@
+ from .mobilenet import *
onnxtr/models/classification/models/mobilenet.py
@@ -0,0 +1,120 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
+
+ from copy import deepcopy
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+
+ from ...engine import Engine
+
+ __all__ = [
+     "mobilenet_v3_small_crop_orientation",
+     "mobilenet_v3_small_page_orientation",
+ ]
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "mobilenet_v3_small_crop_orientation": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 256, 256),
+         "classes": [0, -90, 180, 90],
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/mobilenet_v3_small_crop_orientation-5620cf7e.onnx",
+     },
+     "mobilenet_v3_small_page_orientation": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 512, 512),
+         "classes": [0, -90, 180, 90],
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/mobilenet_v3_small_page_orientation-d3f76d79.onnx",
+     },
+ }
+
+
+ class MobileNetV3(Engine):
+     """MobileNetV3 Onnx loader
+
+     Args:
+     ----
+         model_path: path or url to onnx model file
+         cfg: configuration dictionary
+         **kwargs: additional arguments to be passed to `Engine`
+     """
+
+     def __init__(
+         self,
+         model_path: str,
+         cfg: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(url=model_path, **kwargs)
+         self.cfg = cfg
+
+     def __call__(
+         self,
+         x: np.ndarray,
+     ) -> np.ndarray:
+         return self.run(x)
+
+
+ def _mobilenet_v3(
+     arch: str,
+     model_path: str,
+     **kwargs: Any,
+ ) -> MobileNetV3:
+     _cfg = deepcopy(default_cfgs[arch])
+     return MobileNetV3(model_path, cfg=_cfg, **kwargs)
+
+
+ def mobilenet_v3_small_crop_orientation(
+     model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"], **kwargs: Any
+ ) -> MobileNetV3:
+     """MobileNetV3-Small architecture as described in
+     `"Searching for MobileNetV3",
+     <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import mobilenet_v3_small_crop_orientation
+     >>> model = mobilenet_v3_small_crop_orientation()
+     >>> input_tensor = np.random.rand(1, 3, 256, 256)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the MobileNetV3 architecture
+
+     Returns:
+     -------
+         MobileNetV3
+     """
+     return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, **kwargs)
+
+
+ def mobilenet_v3_small_page_orientation(
+     model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"], **kwargs: Any
+ ) -> MobileNetV3:
+     """MobileNetV3-Small architecture as described in
+     `"Searching for MobileNetV3",
+     <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import mobilenet_v3_small_page_orientation
+     >>> model = mobilenet_v3_small_page_orientation()
+     >>> input_tensor = np.random.rand(1, 3, 512, 512)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the MobileNetV3 architecture
+
+     Returns:
+     -------
+         MobileNetV3
+     """
+     return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, **kwargs)
onnxtr/models/classification/predictor/__init__.py
@@ -0,0 +1 @@
+ from .base import *
onnxtr/models/classification/predictor/base.py
@@ -0,0 +1,57 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, List, Union
+
+ import numpy as np
+ from scipy.special import softmax
+
+ from onnxtr.models.preprocessor import PreProcessor
+ from onnxtr.utils.repr import NestedObject
+
+ __all__ = ["OrientationPredictor"]
+
+
+ class OrientationPredictor(NestedObject):
+     """Implements an object able to detect the reading direction of a text box or a page.
+     4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
+
+     Args:
+     ----
+         pre_processor: transform inputs for easier batched model inference
+         model: core classification architecture (backbone + classification head)
+     """
+
+     _children_names: List[str] = ["pre_processor", "model"]
+
+     def __init__(
+         self,
+         pre_processor: PreProcessor,
+         model: Any,
+     ) -> None:
+         self.pre_processor = pre_processor
+         self.model = model
+
+     def __call__(
+         self,
+         inputs: List[np.ndarray],
+     ) -> List[Union[List[int], List[float]]]:
+         # Dimension check
+         if any(input.ndim != 3 for input in inputs):
+             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
+
+         processed_batches = self.pre_processor(inputs)
+         predicted_batches = [self.model(batch) for batch in processed_batches]
+
+         # confidence
+         probs = [np.max(softmax(batch, axis=1), axis=1) for batch in predicted_batches]
+         # Postprocess predictions
+         predicted_batches = [np.argmax(out_batch, axis=1) for out_batch in predicted_batches]
+
+         class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
+         classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
+         confs = [round(float(p), 2) for prob in probs for p in prob]
+
+         return [class_idxs, classes, confs]
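As a rough sketch of how the pieces above fit together, the snippet below wires the crop-orientation classifier into OrientationPredictor. The PreProcessor keyword arguments mirror the mobilenet default_cfgs entry, but its exact signature lives in onnxtr/models/preprocessor/base.py, which is not shown in this diff, so treat them as assumptions:

```python
import numpy as np

from onnxtr.models import mobilenet_v3_small_crop_orientation
from onnxtr.models.classification.predictor import OrientationPredictor
from onnxtr.models.preprocessor import PreProcessor

model = mobilenet_v3_small_crop_orientation()
# assumed PreProcessor arguments, taken from the model's default_cfgs entry above
pre_processor = PreProcessor(
    output_size=(256, 256), batch_size=8, mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)
)
predictor = OrientationPredictor(pre_processor, model)

# a single dummy crop; real crops come from the detection stage
crops = [np.random.randint(0, 255, (64, 128, 3), dtype=np.uint8)]
class_idxs, angles, confidences = predictor(crops)
print(angles)  # one of the cfg "classes": 0, -90, 180 or 90
```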