onnxtr 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. onnxtr/io/elements.py +17 -4
  2. onnxtr/io/pdf.py +6 -3
  3. onnxtr/models/__init__.py +1 -0
  4. onnxtr/models/_utils.py +57 -20
  5. onnxtr/models/builder.py +24 -9
  6. onnxtr/models/classification/models/mobilenet.py +12 -5
  7. onnxtr/models/classification/zoo.py +18 -6
  8. onnxtr/models/detection/_utils/__init__.py +1 -0
  9. onnxtr/models/detection/_utils/base.py +66 -0
  10. onnxtr/models/detection/models/differentiable_binarization.py +27 -12
  11. onnxtr/models/detection/models/fast.py +30 -9
  12. onnxtr/models/detection/models/linknet.py +24 -9
  13. onnxtr/models/detection/postprocessor/base.py +4 -3
  14. onnxtr/models/detection/predictor/base.py +15 -1
  15. onnxtr/models/detection/zoo.py +12 -3
  16. onnxtr/models/engine.py +73 -7
  17. onnxtr/models/predictor/base.py +65 -42
  18. onnxtr/models/predictor/predictor.py +22 -15
  19. onnxtr/models/recognition/models/crnn.py +24 -9
  20. onnxtr/models/recognition/models/master.py +14 -5
  21. onnxtr/models/recognition/models/parseq.py +14 -5
  22. onnxtr/models/recognition/models/sar.py +12 -5
  23. onnxtr/models/recognition/models/vitstr.py +18 -7
  24. onnxtr/models/recognition/zoo.py +9 -6
  25. onnxtr/models/zoo.py +16 -0
  26. onnxtr/py.typed +0 -0
  27. onnxtr/utils/geometry.py +33 -12
  28. onnxtr/version.py +1 -1
  29. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/METADATA +60 -21
  30. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/RECORD +34 -31
  31. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/WHEEL +1 -1
  32. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/top_level.txt +0 -1
  33. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/LICENSE +0 -0
  34. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/zip-safe +0 -0
onnxtr/io/elements.py CHANGED
@@ -67,10 +67,11 @@ class Word(Element):
67
67
  confidence: the confidence associated with the text prediction
68
68
  geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
69
69
  the page's size
70
+ objectness_score: the objectness score of the detection
70
71
  crop_orientation: the general orientation of the crop in degrees and its confidence
71
72
  """
72
73
 
73
- _exported_keys: List[str] = ["value", "confidence", "geometry", "crop_orientation"]
74
+ _exported_keys: List[str] = ["value", "confidence", "geometry", "objectness_score", "crop_orientation"]
74
75
  _children_names: List[str] = []
75
76
 
76
77
  def __init__(
@@ -78,12 +79,14 @@ class Word(Element):
78
79
  value: str,
79
80
  confidence: float,
80
81
  geometry: Union[BoundingBox, np.ndarray],
82
+ objectness_score: float,
81
83
  crop_orientation: Dict[str, Any],
82
84
  ) -> None:
83
85
  super().__init__()
84
86
  self.value = value
85
87
  self.confidence = confidence
86
88
  self.geometry = geometry
89
+ self.objectness_score = objectness_score
87
90
  self.crop_orientation = crop_orientation
88
91
 
89
92
  def render(self) -> str:
@@ -143,7 +146,7 @@ class Line(Element):
143
146
  all words in it.
144
147
  """
145
148
 
146
- _exported_keys: List[str] = ["geometry"]
149
+ _exported_keys: List[str] = ["geometry", "objectness_score"]
147
150
  _children_names: List[str] = ["words"]
148
151
  words: List[Word] = []
149
152
 
@@ -151,7 +154,11 @@ class Line(Element):
151
154
  self,
152
155
  words: List[Word],
153
156
  geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
157
+ objectness_score: Optional[float] = None,
154
158
  ) -> None:
159
+ # Compute the objectness score of the line
160
+ if objectness_score is None:
161
+ objectness_score = float(np.mean([w.objectness_score for w in words]))
155
162
  # Resolve the geometry using the smallest enclosing bounding box
156
163
  if geometry is None:
157
164
  # Check whether this is a rotated or straight box
@@ -160,6 +167,7 @@ class Line(Element):
160
167
 
161
168
  super().__init__(words=words)
162
169
  self.geometry = geometry
170
+ self.objectness_score = objectness_score
163
171
 
164
172
  def render(self) -> str:
165
173
  """Renders the full text of the element"""
@@ -186,7 +194,7 @@ class Block(Element):
186
194
  all lines and artefacts in it.
187
195
  """
188
196
 
189
- _exported_keys: List[str] = ["geometry"]
197
+ _exported_keys: List[str] = ["geometry", "objectness_score"]
190
198
  _children_names: List[str] = ["lines", "artefacts"]
191
199
  lines: List[Line] = []
192
200
  artefacts: List[Artefact] = []
@@ -196,7 +204,11 @@ class Block(Element):
196
204
  lines: List[Line] = [],
197
205
  artefacts: List[Artefact] = [],
198
206
  geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
207
+ objectness_score: Optional[float] = None,
199
208
  ) -> None:
209
+ # Compute the objectness score of the block
210
+ if objectness_score is None:
211
+ objectness_score = float(np.mean([w.objectness_score for line in lines for w in line.words]))
200
212
  # Resolve the geometry using the smallest enclosing bounding box
201
213
  if geometry is None:
202
214
  line_boxes = [word.geometry for line in lines for word in line.words]
@@ -208,6 +220,7 @@ class Block(Element):
208
220
 
209
221
  super().__init__(lines=lines, artefacts=artefacts)
210
222
  self.geometry = geometry
223
+ self.objectness_score = objectness_score
211
224
 
212
225
  def render(self, line_break: str = "\n") -> str:
213
226
  """Renders the full text of the element"""
@@ -314,7 +327,7 @@ class Page(Element):
314
327
  SubElement(
315
328
  head,
316
329
  "meta",
317
- attrib={"name": "ocr-system", "content": f" {onnxtr.__version__}"}, # type: ignore[attr-defined]
330
+ attrib={"name": "ocr-system", "content": f"onnxtr {onnxtr.__version__}"}, # type: ignore[attr-defined]
318
331
  )
319
332
  SubElement(
320
333
  head,
onnxtr/io/pdf.py CHANGED
@@ -15,7 +15,7 @@ __all__ = ["read_pdf"]
15
15
 
16
16
  def read_pdf(
17
17
  file: AbstractFile,
18
- scale: float = 2,
18
+ scale: int = 2,
19
19
  rgb_mode: bool = True,
20
20
  password: Optional[str] = None,
21
21
  **kwargs: Any,
@@ -38,5 +38,8 @@ def read_pdf(
38
38
  the list of pages decoded as numpy ndarray of shape H x W x C
39
39
  """
40
40
  # Rasterise pages to numpy ndarrays with pypdfium2
41
- pdf = pdfium.PdfDocument(file, password=password, autoclose=True)
42
- return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf]
41
+ pdf = pdfium.PdfDocument(file, password=password)
42
+ try:
43
+ return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf]
44
+ finally:
45
+ pdf.close()
onnxtr/models/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
+ from .engine import EngineConfig
1
2
  from .classification import *
2
3
  from .detection import *
3
4
  from .recognition import *
onnxtr/models/_utils.py CHANGED
@@ -11,6 +11,8 @@ import cv2
11
11
  import numpy as np
12
12
  from langdetect import LangDetectException, detect_langs
13
13
 
14
+ from onnxtr.utils.geometry import rotate_image
15
+
14
16
  __all__ = ["estimate_orientation", "get_language"]
15
17
 
16
18
 
@@ -29,56 +31,91 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
29
31
  return max(w / h, h / w)
30
32
 
31
33
 
32
- def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> int:
34
+ def estimate_orientation(
35
+ img: np.ndarray,
36
+ general_page_orientation: Optional[Tuple[int, float]] = None,
37
+ n_ct: int = 70,
38
+ ratio_threshold_for_lines: float = 3,
39
+ min_confidence: float = 0.2,
40
+ lower_area: int = 100,
41
+ ) -> int:
33
42
  """Estimate the angle of the general document orientation based on the
34
43
  lines of the document and the assumption that they should be horizontal.
35
44
 
36
45
  Args:
37
46
  ----
38
47
  img: the img or bitmap to analyze (H, W, C)
48
+ general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence)
49
+ estimated by a model
39
50
  n_ct: the number of contours used for the orientation estimation
40
51
  ratio_threshold_for_lines: this is the ratio w/h used to discriminate lines
52
+ min_confidence: the minimum confidence to consider the general_page_orientation
53
+ lower_area: the minimum area of a contour to be considered
41
54
 
42
55
  Returns:
43
56
  -------
44
- the angle of the general document orientation
57
+ the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation)
45
58
  """
46
59
  assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
47
- max_value = np.max(img)
48
- min_value = np.min(img)
49
- if max_value <= 1 and min_value >= 0 or (max_value <= 255 and min_value >= 0 and img.shape[-1] == 1):
50
- thresh = img.astype(np.uint8)
51
- if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
60
+ thresh = None
61
+ # Convert image to grayscale if necessary
62
+ if img.shape[-1] == 3:
52
63
  gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
53
64
  gray_img = cv2.medianBlur(gray_img, 5)
54
- thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1] # type: ignore[assignment]
55
-
56
- # try to merge words in lines
57
- (h, w) = img.shape[:2]
58
- k_x = max(1, (floor(w / 100)))
59
- k_y = max(1, (floor(h / 100)))
60
- kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
61
- thresh = cv2.dilate(thresh, kernel, iterations=1) # type: ignore[assignment]
65
+ thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
66
+ else:
67
+ thresh = img.astype(np.uint8) # type: ignore[assignment]
68
+
69
+ page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
70
+ if page_orientation and orientation_confidence >= min_confidence:
71
+ # We rotate the image to the general orientation which improves the detection
72
+ # No expand needed bitmap is already padded
73
+ thresh = rotate_image(thresh, -page_orientation) # type: ignore
74
+ else: # That's only required if we do not work on the detection models bin map
75
+ # try to merge words in lines
76
+ (h, w) = img.shape[:2]
77
+ k_x = max(1, (floor(w / 100)))
78
+ k_y = max(1, (floor(h / 100)))
79
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
80
+ thresh = cv2.dilate(thresh, kernel, iterations=1)
62
81
 
63
82
  # extract contours
64
83
  contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
65
84
 
66
- # Sort contours
67
- contours = sorted(contours, key=get_max_width_length_ratio, reverse=True)
85
+ # Filter & Sort contours
86
+ contours = sorted(
87
+ [contour for contour in contours if cv2.contourArea(contour) > lower_area],
88
+ key=get_max_width_length_ratio,
89
+ reverse=True,
90
+ )
68
91
 
69
92
  angles = []
70
93
  for contour in contours[:n_ct]:
71
- _, (w, h), angle = cv2.minAreaRect(contour)
94
+ _, (w, h), angle = cv2.minAreaRect(contour) # type: ignore[assignment]
72
95
  if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines
73
96
  angles.append(angle)
74
97
  elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, subtract 90 degrees
75
98
  angles.append(angle - 90)
76
99
 
77
100
  if len(angles) == 0:
78
- return 0 # in case no angles is found
101
+ estimated_angle = 0 # in case no angle is found
79
102
  else:
80
103
  median = -median_low(angles)
81
- return round(median) if abs(median) != 0 else 0
104
+ estimated_angle = -round(median) if abs(median) != 0 else 0
105
+
106
+ # combine with the general orientation and the estimated angle
107
+ if page_orientation and orientation_confidence >= min_confidence:
108
+ # special case where the estimated angle is mostly wrong:
109
+ # case 1: - and + swapped
110
+ # case 2: estimated angle is completely wrong
111
+ # so in this case we prefer the general page orientation
112
+ if abs(estimated_angle) == abs(page_orientation):
113
+ return page_orientation
114
+ estimated_angle = estimated_angle if page_orientation == 0 else page_orientation + estimated_angle
115
+ if estimated_angle > 180:
116
+ estimated_angle -= 360
117
+
118
+ return estimated_angle # return the clockwise angle (negative - left side rotation, positive - right side rotation)
82
119
 
83
120
 
84
121
  def rectify_crops(
onnxtr/models/builder.py CHANGED
@@ -31,7 +31,7 @@ class DocumentBuilder(NestedObject):
31
31
  def __init__(
32
32
  self,
33
33
  resolve_lines: bool = True,
34
- resolve_blocks: bool = True,
34
+ resolve_blocks: bool = False,
35
35
  paragraph_break: float = 0.035,
36
36
  export_as_straight_boxes: bool = False,
37
37
  ) -> None:
@@ -223,6 +223,7 @@ class DocumentBuilder(NestedObject):
223
223
  def _build_blocks(
224
224
  self,
225
225
  boxes: np.ndarray,
226
+ objectness_scores: np.ndarray,
226
227
  word_preds: List[Tuple[str, float]],
227
228
  crop_orientations: List[Dict[str, Any]],
228
229
  ) -> List[Block]:
@@ -230,7 +231,8 @@ class DocumentBuilder(NestedObject):
230
231
 
231
232
  Args:
232
233
  ----
233
- boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
234
+ boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
235
+ objectness_scores: objectness scores of all detected words of the page, of shape N
234
236
  word_preds: list of all detected words of the page, of shape N
235
237
  crop_orientations: list of dictoinaries containing
236
238
  the general orientation (orientations + confidences) of the crops
@@ -265,12 +267,14 @@ class DocumentBuilder(NestedObject):
265
267
  Word(
266
268
  *word_preds[idx],
267
269
  tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type]
270
+ float(objectness_scores[idx]),
268
271
  crop_orientations[idx],
269
272
  )
270
273
  if boxes.ndim == 3
271
274
  else Word(
272
275
  *word_preds[idx],
273
276
  ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
277
+ float(objectness_scores[idx]),
274
278
  crop_orientations[idx],
275
279
  )
276
280
  for idx in line
@@ -293,6 +297,7 @@ class DocumentBuilder(NestedObject):
293
297
  self,
294
298
  pages: List[np.ndarray],
295
299
  boxes: List[np.ndarray],
300
+ objectness_scores: List[np.ndarray],
296
301
  text_preds: List[List[Tuple[str, float]]],
297
302
  page_shapes: List[Tuple[int, int]],
298
303
  crop_orientations: List[Dict[str, Any]],
@@ -304,8 +309,9 @@ class DocumentBuilder(NestedObject):
304
309
  Args:
305
310
  ----
306
311
  pages: list of N elements, where each element represents the page image
307
- boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5)
308
- or (*, 6) for all words for a given page
312
+ boxes: list of N elements, where each element represents the localization predictions, of shape (*, 4)
313
+ or (*, 4, 2) for all words for a given page
314
+ objectness_scores: list of N elements, where each element represents the objectness scores
309
315
  text_preds: list of N elements, where each element is the list of all word prediction (text + confidence)
310
316
  page_shapes: shape of each page, of size N
311
317
  crop_orientations: list of N elements, where each element is
@@ -319,9 +325,9 @@ class DocumentBuilder(NestedObject):
319
325
  -------
320
326
  document object
321
327
  """
322
- if len(boxes) != len(text_preds) != len(crop_orientations) or len(boxes) != len(page_shapes) != len(
323
- crop_orientations
324
- ):
328
+ if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
329
+ page_shapes
330
+ ) != len(crop_orientations) != len(objectness_scores):
325
331
  raise ValueError("All arguments are expected to be lists of the same size")
326
332
 
327
333
  _orientations = (
@@ -339,6 +345,7 @@ class DocumentBuilder(NestedObject):
339
345
  page,
340
346
  self._build_blocks(
341
347
  page_boxes,
348
+ loc_scores,
342
349
  word_preds,
343
350
  word_crop_orientations,
344
351
  ),
@@ -347,8 +354,16 @@ class DocumentBuilder(NestedObject):
347
354
  orientation,
348
355
  language,
349
356
  )
350
- for page, _idx, shape, page_boxes, word_preds, word_crop_orientations, orientation, language in zip(
351
- pages, range(len(boxes)), page_shapes, boxes, text_preds, crop_orientations, _orientations, _languages
357
+ for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language in zip( # noqa: E501
358
+ pages,
359
+ range(len(boxes)),
360
+ page_shapes,
361
+ boxes,
362
+ objectness_scores,
363
+ text_preds,
364
+ crop_orientations,
365
+ _orientations,
366
+ _languages,
352
367
  )
353
368
  ]
354
369
 
@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional
10
10
 
11
11
  import numpy as np
12
12
 
13
- from ...engine import Engine
13
+ from ...engine import Engine, EngineConfig
14
14
 
15
15
  __all__ = [
16
16
  "mobilenet_v3_small_crop_orientation",
@@ -43,6 +43,7 @@ class MobileNetV3(Engine):
43
43
  Args:
44
44
  ----
45
45
  model_path: path or url to onnx model file
46
+ engine_cfg: configuration for the inference engine
46
47
  cfg: configuration dictionary
47
48
  **kwargs: additional arguments to be passed to `Engine`
48
49
  """
@@ -50,10 +51,11 @@ class MobileNetV3(Engine):
50
51
  def __init__(
51
52
  self,
52
53
  model_path: str,
54
+ engine_cfg: EngineConfig = EngineConfig(),
53
55
  cfg: Optional[Dict[str, Any]] = None,
54
56
  **kwargs: Any,
55
57
  ) -> None:
56
- super().__init__(url=model_path, **kwargs)
58
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
57
59
  self.cfg = cfg
58
60
 
59
61
  def __call__(
@@ -67,17 +69,19 @@ def _mobilenet_v3(
67
69
  arch: str,
68
70
  model_path: str,
69
71
  load_in_8_bit: bool = False,
72
+ engine_cfg: EngineConfig = EngineConfig(),
70
73
  **kwargs: Any,
71
74
  ) -> MobileNetV3:
72
75
  # Patch the url
73
76
  model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
74
77
  _cfg = deepcopy(default_cfgs[arch])
75
- return MobileNetV3(model_path, cfg=_cfg, **kwargs)
78
+ return MobileNetV3(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
76
79
 
77
80
 
78
81
  def mobilenet_v3_small_crop_orientation(
79
82
  model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
80
83
  load_in_8_bit: bool = False,
84
+ engine_cfg: EngineConfig = EngineConfig(),
81
85
  **kwargs: Any,
82
86
  ) -> MobileNetV3:
83
87
  """MobileNetV3-Small architecture as described in
@@ -94,18 +98,20 @@ def mobilenet_v3_small_crop_orientation(
94
98
  ----
95
99
  model_path: path to onnx model file, defaults to url in default_cfgs
96
100
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
101
+ engine_cfg: configuration for the inference engine
97
102
  **kwargs: keyword arguments of the MobileNetV3 architecture
98
103
 
99
104
  Returns:
100
105
  -------
101
106
  MobileNetV3
102
107
  """
103
- return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, **kwargs)
108
+ return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
104
109
 
105
110
 
106
111
  def mobilenet_v3_small_page_orientation(
107
112
  model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
108
113
  load_in_8_bit: bool = False,
114
+ engine_cfg: EngineConfig = EngineConfig(),
109
115
  **kwargs: Any,
110
116
  ) -> MobileNetV3:
111
117
  """MobileNetV3-Small architecture as described in
@@ -122,10 +128,11 @@ def mobilenet_v3_small_page_orientation(
122
128
  ----
123
129
  model_path: path to onnx model file, defaults to url in default_cfgs
124
130
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
131
+ engine_cfg: configuration for the inference engine
125
132
  **kwargs: keyword arguments of the MobileNetV3 architecture
126
133
 
127
134
  Returns:
128
135
  -------
129
136
  MobileNetV3
130
137
  """
131
- return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, **kwargs)
138
+ return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -5,6 +5,8 @@
5
5
 
6
6
  from typing import Any, List
7
7
 
8
+ from onnxtr.models.engine import EngineConfig
9
+
8
10
  from .. import classification
9
11
  from ..preprocessor import PreProcessor
10
12
  from .predictor import OrientationPredictor
@@ -14,12 +16,14 @@ __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
14
16
  ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
15
17
 
16
18
 
17
- def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any) -> OrientationPredictor:
19
+ def _orientation_predictor(
20
+ arch: str, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any
21
+ ) -> OrientationPredictor:
18
22
  if arch not in ORIENTATION_ARCHS:
19
23
  raise ValueError(f"unknown architecture '{arch}'")
20
24
 
21
25
  # Load directly classifier from backbone
22
- _model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit)
26
+ _model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg)
23
27
  kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
24
28
  kwargs["std"] = kwargs.get("std", _model.cfg["std"])
25
29
  kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
@@ -32,7 +36,10 @@ def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any
32
36
 
33
37
 
34
38
  def crop_orientation_predictor(
35
- arch: Any = "mobilenet_v3_small_crop_orientation", load_in_8_bit: bool = False, **kwargs: Any
39
+ arch: Any = "mobilenet_v3_small_crop_orientation",
40
+ load_in_8_bit: bool = False,
41
+ engine_cfg: EngineConfig = EngineConfig(),
42
+ **kwargs: Any,
36
43
  ) -> OrientationPredictor:
37
44
  """Crop orientation classification architecture.
38
45
 
@@ -46,17 +53,21 @@ def crop_orientation_predictor(
46
53
  ----
47
54
  arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
48
55
  load_in_8_bit: load the 8-bit quantized version of the model
56
+ engine_cfg: configuration of inference engine
49
57
  **kwargs: keyword arguments to be passed to the OrientationPredictor
50
58
 
51
59
  Returns:
52
60
  -------
53
61
  OrientationPredictor
54
62
  """
55
- return _orientation_predictor(arch, load_in_8_bit, **kwargs)
63
+ return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
56
64
 
57
65
 
58
66
  def page_orientation_predictor(
59
- arch: Any = "mobilenet_v3_small_page_orientation", load_in_8_bit: bool = False, **kwargs: Any
67
+ arch: Any = "mobilenet_v3_small_page_orientation",
68
+ load_in_8_bit: bool = False,
69
+ engine_cfg: EngineConfig = EngineConfig(),
70
+ **kwargs: Any,
60
71
  ) -> OrientationPredictor:
61
72
  """Page orientation classification architecture.
62
73
 
@@ -70,10 +81,11 @@ def page_orientation_predictor(
70
81
  ----
71
82
  arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
72
83
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
84
+ engine_cfg: configuration for the inference engine
73
85
  **kwargs: keyword arguments to be passed to the OrientationPredictor
74
86
 
75
87
  Returns:
76
88
  -------
77
89
  OrientationPredictor
78
90
  """
79
- return _orientation_predictor(arch, load_in_8_bit, **kwargs)
91
+ return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
@@ -0,0 +1 @@
1
+ from .base import *
@@ -0,0 +1,66 @@
1
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from typing import List
7
+
8
+ import numpy as np
9
+
10
+ __all__ = ["_remove_padding"]
11
+
12
+
13
+ def _remove_padding(
14
+ pages: List[np.ndarray],
15
+ loc_preds: List[np.ndarray],
16
+ preserve_aspect_ratio: bool,
17
+ symmetric_pad: bool,
18
+ assume_straight_pages: bool,
19
+ ) -> List[np.ndarray]:
20
+ """Remove padding from the localization predictions
21
+
22
+ Args:
23
+ ----
24
+
25
+ pages: list of pages
26
+ loc_preds: list of localization predictions
27
+ preserve_aspect_ratio: whether the aspect ratio was preserved during padding
28
+ symmetric_pad: whether the padding was symmetric
29
+ assume_straight_pages: whether the pages are assumed to be straight
30
+
31
+ Returns:
32
+ -------
33
+ list of unpadded localization predictions
34
+ """
35
+ if preserve_aspect_ratio:
36
+ # Rectify loc_preds to remove padding
37
+ rectified_preds = []
38
+ for page, loc_pred in zip(pages, loc_preds):
39
+ h, w = page.shape[0], page.shape[1]
40
+ if h > w:
41
+ # y unchanged, dilate x coord
42
+ if symmetric_pad:
43
+ if assume_straight_pages:
44
+ loc_pred[:, [0, 2]] = (loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5
45
+ else:
46
+ loc_pred[:, :, 0] = (loc_pred[:, :, 0] - 0.5) * h / w + 0.5
47
+ else:
48
+ if assume_straight_pages:
49
+ loc_pred[:, [0, 2]] *= h / w
50
+ else:
51
+ loc_pred[:, :, 0] *= h / w
52
+ elif w > h:
53
+ # x unchanged, dilate y coord
54
+ if symmetric_pad:
55
+ if assume_straight_pages:
56
+ loc_pred[:, [1, 3]] = (loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5
57
+ else:
58
+ loc_pred[:, :, 1] = (loc_pred[:, :, 1] - 0.5) * w / h + 0.5
59
+ else:
60
+ if assume_straight_pages:
61
+ loc_pred[:, [1, 3]] *= w / h
62
+ else:
63
+ loc_pred[:, :, 1] *= w / h
64
+ rectified_preds.append(np.clip(loc_pred, 0, 1))
65
+ return rectified_preds
66
+ return loc_preds