onnxtr 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. onnxtr/contrib/__init__.py +1 -0
  2. onnxtr/contrib/artefacts.py +6 -8
  3. onnxtr/contrib/base.py +7 -16
  4. onnxtr/file_utils.py +1 -3
  5. onnxtr/io/elements.py +54 -60
  6. onnxtr/io/html.py +0 -2
  7. onnxtr/io/image.py +1 -4
  8. onnxtr/io/pdf.py +3 -5
  9. onnxtr/io/reader.py +4 -10
  10. onnxtr/models/_utils.py +10 -17
  11. onnxtr/models/builder.py +17 -30
  12. onnxtr/models/classification/models/mobilenet.py +7 -12
  13. onnxtr/models/classification/predictor/base.py +6 -7
  14. onnxtr/models/classification/zoo.py +25 -11
  15. onnxtr/models/detection/_utils/base.py +3 -7
  16. onnxtr/models/detection/core.py +2 -8
  17. onnxtr/models/detection/models/differentiable_binarization.py +10 -17
  18. onnxtr/models/detection/models/fast.py +10 -17
  19. onnxtr/models/detection/models/linknet.py +10 -17
  20. onnxtr/models/detection/postprocessor/base.py +3 -9
  21. onnxtr/models/detection/predictor/base.py +4 -5
  22. onnxtr/models/detection/zoo.py +20 -6
  23. onnxtr/models/engine.py +9 -9
  24. onnxtr/models/factory/hub.py +3 -7
  25. onnxtr/models/predictor/base.py +29 -30
  26. onnxtr/models/predictor/predictor.py +4 -5
  27. onnxtr/models/preprocessor/base.py +8 -12
  28. onnxtr/models/recognition/core.py +0 -1
  29. onnxtr/models/recognition/models/crnn.py +11 -23
  30. onnxtr/models/recognition/models/master.py +9 -15
  31. onnxtr/models/recognition/models/parseq.py +8 -12
  32. onnxtr/models/recognition/models/sar.py +8 -12
  33. onnxtr/models/recognition/models/vitstr.py +9 -15
  34. onnxtr/models/recognition/predictor/_utils.py +6 -9
  35. onnxtr/models/recognition/predictor/base.py +3 -3
  36. onnxtr/models/recognition/utils.py +2 -7
  37. onnxtr/models/recognition/zoo.py +19 -7
  38. onnxtr/models/zoo.py +7 -9
  39. onnxtr/transforms/base.py +17 -6
  40. onnxtr/utils/common_types.py +7 -8
  41. onnxtr/utils/data.py +7 -11
  42. onnxtr/utils/fonts.py +1 -6
  43. onnxtr/utils/geometry.py +18 -49
  44. onnxtr/utils/multithreading.py +3 -5
  45. onnxtr/utils/reconstitution.py +139 -38
  46. onnxtr/utils/repr.py +1 -2
  47. onnxtr/utils/visualization.py +12 -21
  48. onnxtr/utils/vocabs.py +1 -2
  49. onnxtr/version.py +1 -1
  50. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/METADATA +71 -41
  51. onnxtr-0.6.0.dist-info/RECORD +75 -0
  52. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/WHEEL +1 -1
  53. onnxtr-0.5.0.dist-info/RECORD +0 -75
  54. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/LICENSE +0 -0
  55. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/top_level.txt +0 -0
  56. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/zip-safe +0 -0
onnxtr/models/_utils.py CHANGED
@@ -5,7 +5,6 @@

  from math import floor
  from statistics import median_low
- from typing import List, Optional, Tuple

  import cv2
  import numpy as np
@@ -20,11 +19,9 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
  """Get the maximum shape ratio of a contour.

  Args:
- ----
  contour: the contour from cv2.findContour

  Returns:
- -------
  the maximum shape ratio
  """
  _, (w, h), _ = cv2.minAreaRect(contour)
@@ -33,7 +30,7 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:

  def estimate_orientation(
  img: np.ndarray,
- general_page_orientation: Optional[Tuple[int, float]] = None,
+ general_page_orientation: tuple[int, float] | None = None,
  n_ct: int = 70,
  ratio_threshold_for_lines: float = 3,
  min_confidence: float = 0.2,
@@ -43,7 +40,6 @@ def estimate_orientation(
  lines of the document and the assumption that they should be horizontal.

  Args:
- ----
  img: the img or bitmap to analyze (H, W, C)
  general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence)
  estimated by a model
@@ -53,7 +49,6 @@ def estimate_orientation(
  lower_area: the minimum area of a contour to be considered

  Returns:
- -------
  the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation)
  """
  assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
@@ -64,13 +59,13 @@
  gray_img = cv2.medianBlur(gray_img, 5)
  thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
  else:
- thresh = img.astype(np.uint8)  # type: ignore[assignment]
+ thresh = img.astype(np.uint8)

  page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
  if page_orientation and orientation_confidence >= min_confidence:
  # We rotate the image to the general orientation which improves the detection
  # No expand needed bitmap is already padded
- thresh = rotate_image(thresh, -page_orientation)  # type: ignore
+ thresh = rotate_image(thresh, -page_orientation)
  else:  # That's only required if we do not work on the detection models bin map
  # try to merge words in lines
  (h, w) = img.shape[:2]
@@ -91,7 +86,7 @@

  angles = []
  for contour in contours[:n_ct]:
- _, (w, h), angle = cv2.minAreaRect(contour)  # type: ignore[assignment]
+ _, (w, h), angle = cv2.minAreaRect(contour)
  if w / h > ratio_threshold_for_lines:  # select only contours with ratio like lines
  angles.append(angle)
  elif w / h < 1 / ratio_threshold_for_lines:  # if lines are vertical, substract 90 degree
@@ -119,9 +114,9 @@


  def rectify_crops(
- crops: List[np.ndarray],
- orientations: List[int],
- ) -> List[np.ndarray]:
+ crops: list[np.ndarray],
+ orientations: list[int],
+ ) -> list[np.ndarray]:
  """Rotate each crop of the list according to the predicted orientation:
  0: already straight, no rotation
  1: 90 ccw, rotate 3 times ccw
@@ -139,8 +134,8 @@

  def rectify_loc_preds(
  page_loc_preds: np.ndarray,
- orientations: List[int],
- ) -> Optional[np.ndarray]:
+ orientations: list[int],
+ ) -> np.ndarray | None:
  """Orient the quadrangle (Polygon4P) according to the predicted orientation,
  so that the points are in this order: top L, top R, bot R, bot L if the crop is readable
  """
@@ -157,16 +152,14 @@
  )


- def get_language(text: str) -> Tuple[str, float]:
+ def get_language(text: str) -> tuple[str, float]:
  """Get languages of a text using langdetect model.
  Get the language with the highest probability or no language if only a few words or a low probability

  Args:
- ----
  text (str): text

  Returns:
- -------
  The detected language in ISO 639 code and confidence score
  """
  try:
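The functions above keep their behaviour; only the annotations move to built-in generics and `X | None` unions. For orientation, a minimal usage sketch of the two helpers whose signatures changed, `estimate_orientation` and `get_language`; the image content, shapes, and confidence values below are illustrative assumptions, not taken from the diff.

import numpy as np
from onnxtr.models._utils import estimate_orientation, get_language

# Illustrative (H, W, C) page; in practice this is a page bitmap or a detection heatmap
page = (np.random.rand(1024, 768, 3) * 255).astype(np.uint8)

# Optional hint from a page-orientation model: (angle in {0, 90, 180, 270}, confidence)
angle = estimate_orientation(page, general_page_orientation=(0, 0.9), min_confidence=0.2)

lang, confidence = get_language("This is a short sample of text.")  # ISO 639 code + confidence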
onnxtr/models/builder.py CHANGED
@@ -4,7 +4,7 @@
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any

  import numpy as np
  from scipy.cluster.hierarchy import fclusterdata
@@ -20,7 +20,6 @@ class DocumentBuilder(NestedObject):
  """Implements a document builder

  Args:
- ----
  resolve_lines: whether words should be automatically grouped into lines
  resolve_blocks: whether lines should be automatically grouped into blocks
  paragraph_break: relative length of the minimum space separating paragraphs
@@ -41,15 +40,13 @@ class DocumentBuilder(NestedObject):
  self.export_as_straight_boxes = export_as_straight_boxes

  @staticmethod
- def _sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ def _sort_boxes(boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
  """Sort bounding boxes from top to bottom, left to right

  Args:
- ----
  boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox)

  Returns:
- -------
  tuple: indices of ordered boxes of shape (N,), boxes
  If straight boxes are passed tpo the function, boxes are unchanged
  else: boxes returned are straight boxes fitted to the straightened rotated boxes
@@ -65,16 +62,14 @@ class DocumentBuilder(NestedObject):
  boxes = np.concatenate((boxes.min(1), boxes.max(1)), -1)
  return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort(), boxes

- def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: List[int]) -> List[List[int]]:
+ def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: list[int]) -> list[list[int]]:
  """Split a line in sub_lines

  Args:
- ----
  boxes: bounding boxes of shape (N, 4)
  word_idcs: list of indexes for the words of the line

  Returns:
- -------
  A list of (sub-)lines computed from the original line (words)
  """
  lines = []
@@ -105,15 +100,13 @@ class DocumentBuilder(NestedObject):

  return lines

- def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]:
+ def _resolve_lines(self, boxes: np.ndarray) -> list[list[int]]:
  """Order boxes to group them in lines

  Args:
- ----
  boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox

  Returns:
- -------
  nested list of box indices
  """
  # Sort boxes, and straighten the boxes if they are rotated
@@ -153,16 +146,14 @@ class DocumentBuilder(NestedObject):
  return lines

  @staticmethod
- def _resolve_blocks(boxes: np.ndarray, lines: List[List[int]]) -> List[List[List[int]]]:
+ def _resolve_blocks(boxes: np.ndarray, lines: list[list[int]]) -> list[list[list[int]]]:
  """Order lines to group them in blocks

  Args:
- ----
  boxes: bounding boxes of shape (N, 4) or (N, 4, 2)
  lines: list of lines, each line is a list of idx

  Returns:
- -------
  nested list of box indices
  """
  # Resolve enclosing boxes of lines
@@ -207,7 +198,7 @@ class DocumentBuilder(NestedObject):
  # Compute clusters
  clusters = fclusterdata(box_features, t=0.1, depth=4, criterion="distance", metric="euclidean")

- _blocks: Dict[int, List[int]] = {}
+ _blocks: dict[int, list[int]] = {}
  # Form clusters
  for line_idx, cluster_idx in enumerate(clusters):
  if cluster_idx in _blocks.keys():
@@ -224,13 +215,12 @@ class DocumentBuilder(NestedObject):
  self,
  boxes: np.ndarray,
  objectness_scores: np.ndarray,
- word_preds: List[Tuple[str, float]],
- crop_orientations: List[Dict[str, Any]],
- ) -> List[Block]:
+ word_preds: list[tuple[str, float]],
+ crop_orientations: list[dict[str, Any]],
+ ) -> list[Block]:
  """Gather independent words in structured blocks

  Args:
- ----
  boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
  objectness_scores: objectness scores of all detected words of the page, of shape N
  word_preds: list of all detected words of the page, of shape N
@@ -238,7 +228,6 @@ class DocumentBuilder(NestedObject):
  the general orientation (orientations + confidences) of the crops

  Returns:
- -------
  list of block elements
  """
  if boxes.shape[0] != len(word_preds):
@@ -295,19 +284,18 @@ class DocumentBuilder(NestedObject):

  def __call__(
  self,
- pages: List[np.ndarray],
- boxes: List[np.ndarray],
- objectness_scores: List[np.ndarray],
- text_preds: List[List[Tuple[str, float]]],
- page_shapes: List[Tuple[int, int]],
- crop_orientations: List[Dict[str, Any]],
- orientations: Optional[List[Dict[str, Any]]] = None,
- languages: Optional[List[Dict[str, Any]]] = None,
+ pages: list[np.ndarray],
+ boxes: list[np.ndarray],
+ objectness_scores: list[np.ndarray],
+ text_preds: list[list[tuple[str, float]]],
+ page_shapes: list[tuple[int, int]],
+ crop_orientations: list[dict[str, Any]],
+ orientations: list[dict[str, Any]] | None = None,
+ languages: list[dict[str, Any]] | None = None,
  ) -> Document:
  """Re-arrange detected words into structured blocks

  Args:
- ----
  pages: list of N elements, where each element represents the page image
  boxes: list of N elements, where each element represents the localization predictions, of shape (*, 4)
  or (*, 4, 2) for all words for a given page
@@ -322,7 +310,6 @@ class DocumentBuilder(NestedObject):
  where each element is a dictionary containing the language (language + confidence)

  Returns:
- -------
  document object
  """
  if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
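The same annotation migration runs through every file in this release: `typing.List/Dict/Tuple` become the built-in generics of PEP 585, and `Optional[X]` becomes the PEP 604 union `X | None`. A minimal before/after sketch of the pattern; the function below is made up for illustration and is not library code.

from typing import Any
import numpy as np

# 0.5.x style:
#   def build(boxes: List[np.ndarray], languages: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]: ...
# 0.6.0 style:
def build(boxes: list[np.ndarray], languages: list[dict[str, Any]] | None = None) -> dict[str, Any]:
    """Illustrative only: same runtime behaviour, only the annotations change."""
    return {"num_pages": len(boxes), "languages": languages or []}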
onnxtr/models/classification/models/mobilenet.py CHANGED
@@ -6,7 +6,7 @@
  # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py

  from copy import deepcopy
- from typing import Any, Dict, Optional
+ from typing import Any

  import numpy as np

@@ -18,7 +18,7 @@ __all__ = [
  "mobilenet_v3_small_page_orientation",
  ]

- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
  "mobilenet_v3_small_crop_orientation": {
  "mean": (0.694, 0.695, 0.693),
  "std": (0.299, 0.296, 0.301),
@@ -42,7 +42,6 @@ class MobileNetV3(Engine):
  """MobileNetV3 Onnx loader

  Args:
- ----
  model_path: path or url to onnx model file
  engine_cfg: configuration for the inference engine
  cfg: configuration dictionary
@@ -52,8 +51,8 @@ class MobileNetV3(Engine):
  def __init__(
  self,
  model_path: str,
- engine_cfg: Optional[EngineConfig] = None,
- cfg: Optional[Dict[str, Any]] = None,
+ engine_cfg: EngineConfig | None = None,
+ cfg: dict[str, Any] | None = None,
  **kwargs: Any,
  ) -> None:
  super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -71,7 +70,7 @@ def _mobilenet_v3(
  arch: str,
  model_path: str,
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> MobileNetV3:
  # Patch the url
@@ -83,7 +82,7 @@ def _mobilenet_v3(
  def mobilenet_v3_small_crop_orientation(
  model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> MobileNetV3:
  """MobileNetV3-Small architecture as described in
@@ -97,14 +96,12 @@ def mobilenet_v3_small_crop_orientation(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the MobileNetV3 architecture

  Returns:
- -------
  MobileNetV3
  """
  return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -113,7 +110,7 @@ def mobilenet_v3_small_crop_orientation(
  def mobilenet_v3_small_page_orientation(
  model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> MobileNetV3:
  """MobileNetV3-Small architecture as described in
@@ -127,14 +124,12 @@ def mobilenet_v3_small_page_orientation(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the MobileNetV3 architecture

  Returns:
- -------
  MobileNetV3
  """
  return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
onnxtr/models/classification/predictor/base.py CHANGED
@@ -3,7 +3,7 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, List, Optional, Union
+ from typing import Any

  import numpy as np
  from scipy.special import softmax
@@ -19,26 +19,25 @@ class OrientationPredictor(NestedObject):
  4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.

  Args:
- ----
  pre_processor: transform inputs for easier batched model inference
  model: core classification architecture (backbone + classification head)
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  """

- _children_names: List[str] = ["pre_processor", "model"]
+ _children_names: list[str] = ["pre_processor", "model"]

  def __init__(
  self,
- pre_processor: Optional[PreProcessor],
- model: Optional[Any],
+ pre_processor: PreProcessor | None,
+ model: Any | None,
  ) -> None:
  self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
  self.model = model

  def __call__(
  self,
- inputs: List[np.ndarray],
- ) -> List[Union[List[int], List[float]]]:
+ inputs: list[np.ndarray],
+ ) -> list[list[int] | list[float]]:
  # Dimension check
  if any(input.ndim != 3 for input in inputs):
  raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
onnxtr/models/classification/zoo.py CHANGED
@@ -3,7 +3,7 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, List, Optional
+ from typing import Any

  from onnxtr.models.engine import EngineConfig

@@ -13,14 +13,14 @@ from .predictor import OrientationPredictor

  __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]

- ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
+ ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


  def _orientation_predictor(
  arch: Any,
  model_type: str,
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  disabled: bool = False,
  **kwargs: Any,
  ) -> OrientationPredictor:
@@ -51,8 +51,9 @@ def _orientation_predictor(

  def crop_orientation_predictor(
  arch: Any = "mobilenet_v3_small_crop_orientation",
+ batch_size: int = 512,
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> OrientationPredictor:
  """Crop orientation classification architecture.
@@ -64,24 +65,31 @@
  >>> out = model([input_crop])

  Args:
- ----
  arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
+ batch_size: number of samples the model processes in parallel
  load_in_8_bit: load the 8-bit quantized version of the model
  engine_cfg: configuration of inference engine
  **kwargs: keyword arguments to be passed to the OrientationPredictor

  Returns:
- -------
  OrientationPredictor
  """
  model_type = "crop"
- return _orientation_predictor(arch, model_type, load_in_8_bit, engine_cfg, **kwargs)
+ return _orientation_predictor(
+ arch=arch,
+ batch_size=batch_size,
+ model_type=model_type,
+ load_in_8_bit=load_in_8_bit,
+ engine_cfg=engine_cfg,
+ **kwargs,
+ )


  def page_orientation_predictor(
  arch: Any = "mobilenet_v3_small_page_orientation",
+ batch_size: int = 2,
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> OrientationPredictor:
  """Page orientation classification architecture.
@@ -93,15 +101,21 @@
  >>> out = model([input_page])

  Args:
- ----
  arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
+ batch_size: number of samples the model processes in parallel
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments to be passed to the OrientationPredictor

  Returns:
- -------
  OrientationPredictor
  """
  model_type = "page"
- return _orientation_predictor(arch, model_type, load_in_8_bit, engine_cfg, **kwargs)
+ return _orientation_predictor(
+ arch=arch,
+ batch_size=batch_size,
+ model_type=model_type,
+ load_in_8_bit=load_in_8_bit,
+ engine_cfg=engine_cfg,
+ **kwargs,
+ )
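Both factory functions now expose a `batch_size` argument (512 for crops, 2 for pages by default) that is forwarded to `_orientation_predictor`. A minimal usage sketch, assuming the functions are re-exported from `onnxtr.models.classification` as in upstream docTR; the input array is a dummy stand-in for a real word crop.

import numpy as np
from onnxtr.models.classification import crop_orientation_predictor, page_orientation_predictor

crop_clf = crop_orientation_predictor(batch_size=512)  # new argument, default shown in the diff
page_clf = page_orientation_predictor(batch_size=2)

crop = (np.random.rand(64, 256, 3) * 255).astype(np.uint8)  # a single (H, W, C) crop
out = crop_clf([crop])  # inputs must be 3-dimensional, multi-channel 2D images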
onnxtr/models/detection/_utils/base.py CHANGED
@@ -3,7 +3,6 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import List

  import numpy as np

@@ -11,17 +10,15 @@ __all__ = ["_remove_padding"]


  def _remove_padding(
- pages: List[np.ndarray],
- loc_preds: List[np.ndarray],
+ pages: list[np.ndarray],
+ loc_preds: list[np.ndarray],
  preserve_aspect_ratio: bool,
  symmetric_pad: bool,
  assume_straight_pages: bool,
- ) -> List[np.ndarray]:
+ ) -> list[np.ndarray]:
  """Remove padding from the localization predictions

  Args:
- ----
-
  pages: list of pages
  loc_preds: list of localization predictions
  preserve_aspect_ratio: whether the aspect ratio was preserved during padding
@@ -29,7 +26,6 @@ def _remove_padding(
  assume_straight_pages: whether the pages are assumed to be straight

  Returns:
- -------
  list of unpaded localization predictions
  """
  if preserve_aspect_ratio:
onnxtr/models/detection/postprocessor/base.py CHANGED
@@ -3,7 +3,6 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import List

  import cv2
  import numpy as np
@@ -17,7 +16,6 @@ class DetectionPostProcessor(NestedObject):
  """Abstract class to postprocess the raw output of the model

  Args:
- ----
  box_thresh (float): minimal objectness score to consider a box
  bin_thresh (float): threshold to apply to segmentation raw heatmap
  assume straight_pages (bool): if True, fit straight boxes only
@@ -37,13 +35,11 @@
  """Compute the confidence score for a polygon : mean of the p values on the polygon

  Args:
- ----
  pred (np.ndarray): p map returned by the model
  points: coordinates of the polygon
  assume_straight_pages: if True, fit straight boxes only

  Returns:
- -------
  polygon objectness
  """
  h, w = pred.shape[:2]
@@ -71,17 +67,15 @@
  def __call__(
  self,
  proba_map,
- ) -> List[List[np.ndarray]]:
+ ) -> list[list[np.ndarray]]:
  """Performs postprocessing for a list of model outputs

  Args:
- ----
  proba_map: probability map of shape (N, H, W, C)

  Returns:
- -------
  list of N class predictions (for each input sample), where each class predictions is a list of C tensors
- of shape (*, 5) or (*, 6)
+ of shape (*, 5) or (*, 6)
  """
  if proba_map.ndim != 4:
  raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.")
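For context, the "polygon objectness" described in the docstring above is simply the mean of the probability map inside the polygon. A self-contained sketch of that idea, written here for illustration and not the library's exact implementation:

import cv2
import numpy as np

def polygon_mean_score(proba_map: np.ndarray, points: np.ndarray) -> float:
    """proba_map: (H, W) probability map; points: (N, 2) polygon in absolute pixel coordinates."""
    mask = np.zeros(proba_map.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, [points.astype(np.int32)], 1)  # rasterize the polygon
    return float(proba_map[mask.astype(bool)].mean())  # mean p value over the polygon area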