onnxtr-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. onnxtr/__init__.py +2 -0
  2. onnxtr/contrib/__init__.py +0 -0
  3. onnxtr/contrib/artefacts.py +131 -0
  4. onnxtr/contrib/base.py +105 -0
  5. onnxtr/file_utils.py +33 -0
  6. onnxtr/io/__init__.py +5 -0
  7. onnxtr/io/elements.py +455 -0
  8. onnxtr/io/html.py +28 -0
  9. onnxtr/io/image.py +56 -0
  10. onnxtr/io/pdf.py +42 -0
  11. onnxtr/io/reader.py +85 -0
  12. onnxtr/models/__init__.py +4 -0
  13. onnxtr/models/_utils.py +141 -0
  14. onnxtr/models/builder.py +355 -0
  15. onnxtr/models/classification/__init__.py +2 -0
  16. onnxtr/models/classification/models/__init__.py +1 -0
  17. onnxtr/models/classification/models/mobilenet.py +120 -0
  18. onnxtr/models/classification/predictor/__init__.py +1 -0
  19. onnxtr/models/classification/predictor/base.py +57 -0
  20. onnxtr/models/classification/zoo.py +76 -0
  21. onnxtr/models/detection/__init__.py +2 -0
  22. onnxtr/models/detection/core.py +101 -0
  23. onnxtr/models/detection/models/__init__.py +3 -0
  24. onnxtr/models/detection/models/differentiable_binarization.py +159 -0
  25. onnxtr/models/detection/models/fast.py +160 -0
  26. onnxtr/models/detection/models/linknet.py +160 -0
  27. onnxtr/models/detection/postprocessor/__init__.py +0 -0
  28. onnxtr/models/detection/postprocessor/base.py +144 -0
  29. onnxtr/models/detection/predictor/__init__.py +1 -0
  30. onnxtr/models/detection/predictor/base.py +54 -0
  31. onnxtr/models/detection/zoo.py +73 -0
  32. onnxtr/models/engine.py +50 -0
  33. onnxtr/models/predictor/__init__.py +1 -0
  34. onnxtr/models/predictor/base.py +175 -0
  35. onnxtr/models/predictor/predictor.py +145 -0
  36. onnxtr/models/preprocessor/__init__.py +1 -0
  37. onnxtr/models/preprocessor/base.py +118 -0
  38. onnxtr/models/recognition/__init__.py +2 -0
  39. onnxtr/models/recognition/core.py +28 -0
  40. onnxtr/models/recognition/models/__init__.py +5 -0
  41. onnxtr/models/recognition/models/crnn.py +226 -0
  42. onnxtr/models/recognition/models/master.py +145 -0
  43. onnxtr/models/recognition/models/parseq.py +134 -0
  44. onnxtr/models/recognition/models/sar.py +134 -0
  45. onnxtr/models/recognition/models/vitstr.py +166 -0
  46. onnxtr/models/recognition/predictor/__init__.py +1 -0
  47. onnxtr/models/recognition/predictor/_utils.py +86 -0
  48. onnxtr/models/recognition/predictor/base.py +79 -0
  49. onnxtr/models/recognition/utils.py +89 -0
  50. onnxtr/models/recognition/zoo.py +69 -0
  51. onnxtr/models/zoo.py +114 -0
  52. onnxtr/transforms/__init__.py +1 -0
  53. onnxtr/transforms/base.py +112 -0
  54. onnxtr/utils/__init__.py +4 -0
  55. onnxtr/utils/common_types.py +18 -0
  56. onnxtr/utils/data.py +126 -0
  57. onnxtr/utils/fonts.py +41 -0
  58. onnxtr/utils/geometry.py +498 -0
  59. onnxtr/utils/multithreading.py +50 -0
  60. onnxtr/utils/reconstitution.py +70 -0
  61. onnxtr/utils/repr.py +64 -0
  62. onnxtr/utils/visualization.py +291 -0
  63. onnxtr/utils/vocabs.py +71 -0
  64. onnxtr/version.py +1 -0
  65. onnxtr-0.1.0.dist-info/LICENSE +201 -0
  66. onnxtr-0.1.0.dist-info/METADATA +481 -0
  67. onnxtr-0.1.0.dist-info/RECORD +70 -0
  68. onnxtr-0.1.0.dist-info/WHEEL +5 -0
  69. onnxtr-0.1.0.dist-info/top_level.txt +2 -0
  70. onnxtr-0.1.0.dist-info/zip-safe +1 -0
onnxtr/models/detection/postprocessor/base.py
@@ -0,0 +1,144 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+ from typing import List, Union
+
+ import cv2
+ import numpy as np
+ import pyclipper
+ from shapely.geometry import Polygon
+
+ from ..core import DetectionPostProcessor
+
+ __all__ = ["GeneralDetectionPostProcessor"]
+
+
+ class GeneralDetectionPostProcessor(DetectionPostProcessor):
+     """Implements a post processor for the FAST model.
+
+     Args:
+     ----
+         bin_thresh: threshold used to binarize p_map at inference time
+         box_thresh: minimal objectness score to consider a box
+         assume_straight_pages: whether the inputs were expected to have horizontal text elements
+     """
+
+     def __init__(
+         self,
+         bin_thresh: float = 0.1,
+         box_thresh: float = 0.1,
+         assume_straight_pages: bool = True,
+     ) -> None:
+         super().__init__(box_thresh, bin_thresh, assume_straight_pages)
+         self.unclip_ratio = 1.5
+
+     def polygon_to_box(
+         self,
+         points: np.ndarray,
+     ) -> np.ndarray:
+         """Expand a polygon (points) by a factor unclip_ratio, and return the corresponding box
+
+         Args:
+         ----
+             points: coordinates of the polygon to expand
+
+         Returns:
+         -------
+             a box in absolute coordinates (x, y, w, h) or a (4, 2) array (quadrangle)
+         """
+         if not self.assume_straight_pages:
+             # Compute the rectangle polygon enclosing the raw polygon
+             rect = cv2.minAreaRect(points)
+             points = cv2.boxPoints(rect)
+             # Add 1 pixel to correct cv2 approx
+             area = (rect[1][0] + 1) * (1 + rect[1][1])
+             length = 2 * (rect[1][0] + rect[1][1]) + 2
+         else:
+             poly = Polygon(points)
+             area = poly.area
+             length = poly.length
+         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
+         offset = pyclipper.PyclipperOffset()
+         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+         _points = offset.Execute(distance)
+         # Take the biggest stack of points
+         idx = 0
+         if len(_points) > 1:
+             max_size = 0
+             for _idx, p in enumerate(_points):
+                 if len(p) > max_size:
+                     idx = _idx
+                     max_size = len(p)
+             # We ensure that _points can be correctly cast to an ndarray
+             _points = [_points[idx]]
+         expanded_points: np.ndarray = np.asarray(_points)  # expand polygon
+         if len(expanded_points) < 1:
+             return None  # type: ignore[return-value]
+         return (
+             cv2.boundingRect(expanded_points)  # type: ignore[return-value]
+             if self.assume_straight_pages
+             else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
+         )
+
+     def bitmap_to_boxes(
+         self,
+         pred: np.ndarray,
+         bitmap: np.ndarray,
+     ) -> np.ndarray:
+         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
+
+         Args:
+         ----
+             pred: probability map output by the detection model
+             bitmap: binarized map computed from pred
+
+         Returns:
+         -------
+             np array of relative boxes: (*, 5) rows of (xmin, ymin, xmax, ymax, score) for straight pages,
+             otherwise a (*, 4, 2) array of polygons
+         """
+         height, width = bitmap.shape[:2]
+         boxes: List[Union[np.ndarray, List[float]]] = []
+         # get contours from connected components on the bitmap
+         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+         for contour in contours:
+             # Check whether the smallest enclosing bounding box is not too small
+             if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+                 continue
+             # Compute objectness
+             if self.assume_straight_pages:
+                 x, y, w, h = cv2.boundingRect(contour)
+                 points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]])
+                 score = self.box_score(pred, points, assume_straight_pages=True)
+             else:
+                 score = self.box_score(pred, contour, assume_straight_pages=False)
+
+             if score < self.box_thresh:  # remove polygons with a weak objectness
+                 continue
+
+             if self.assume_straight_pages:
+                 _box = self.polygon_to_box(points)
+             else:
+                 _box = self.polygon_to_box(np.squeeze(contour))
+
+             if self.assume_straight_pages:
+                 # compute relative polygon to get rid of img shape
+                 x, y, w, h = _box
+                 xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
+                 boxes.append([xmin, ymin, xmax, ymax, score])
+             else:
+                 # compute relative box to get rid of img shape
+                 _box[:, 0] /= width
+                 _box[:, 1] /= height
+                 boxes.append(_box)
+
+         if not self.assume_straight_pages:
+             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+         else:
+             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
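The expansion in `polygon_to_box` is the DB-style unclip step: the offset distance is `area * unclip_ratio / length`, which pyclipper then applies around the polygon. A minimal standalone sketch of that step on a toy word box (the coordinates and the 1.5 ratio are illustrative and mirror the defaults above):

import numpy as np
import pyclipper
from shapely.geometry import Polygon

# Toy 100x20 px word box in absolute integer coordinates
points = np.array([[10, 10], [110, 10], [110, 30], [10, 30]])

poly = Polygon(points)
unclip_ratio = 1.5
distance = poly.area * unclip_ratio / poly.length  # 2000 * 1.5 / 240 = 12.5 px

offset = pyclipper.PyclipperOffset()
offset.AddPath(points.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.asarray(offset.Execute(distance)[0])

# The expanded contour extends roughly 12.5 px beyond the original box on every side
print(expanded.min(axis=0), expanded.max(axis=0))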
onnxtr/models/detection/predictor/__init__.py
@@ -0,0 +1 @@
+ from .base import *
onnxtr/models/detection/predictor/base.py
@@ -0,0 +1,54 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, List, Tuple, Union
+
+ import numpy as np
+
+ from onnxtr.models.preprocessor import PreProcessor
+ from onnxtr.utils.repr import NestedObject
+
+ __all__ = ["DetectionPredictor"]
+
+
+ class DetectionPredictor(NestedObject):
+     """Implements an object able to localize text elements in a document
+
+     Args:
+     ----
+         pre_processor: transform inputs for easier batched model inference
+         model: core detection architecture
+     """
+
+     _children_names: List[str] = ["pre_processor", "model"]
+
+     def __init__(
+         self,
+         pre_processor: PreProcessor,
+         model: Any,
+     ) -> None:
+         self.pre_processor = pre_processor
+         self.model = model
+
+     def __call__(
+         self,
+         pages: List[np.ndarray],
+         return_maps: bool = False,
+         **kwargs: Any,
+     ) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]:
+         # Dimension check
+         if any(page.ndim != 3 for page in pages):
+             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+
+         processed_batches = self.pre_processor(pages)
+         predicted_batches = [
+             self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
+         ]
+
+         preds = [pred for batch in predicted_batches for pred in batch["preds"]]
+         if return_maps:
+             seg_maps = [pred for batch in predicted_batches for pred in batch["out_map"]]
+             return preds, seg_maps
+         return preds
onnxtr/models/detection/zoo.py
@@ -0,0 +1,73 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any
+
+ from .. import detection
+ from ..preprocessor import PreProcessor
+ from .predictor import DetectionPredictor
+
+ __all__ = ["detection_predictor"]
+
+ ARCHS = [
+     "db_resnet34",
+     "db_resnet50",
+     "db_mobilenet_v3_large",
+     "linknet_resnet18",
+     "linknet_resnet34",
+     "linknet_resnet50",
+     "fast_tiny",
+     "fast_small",
+     "fast_base",
+ ]
+
+
+ def _predictor(arch: Any, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
+     if isinstance(arch, str):
+         if arch not in ARCHS:
+             raise ValueError(f"unknown architecture '{arch}'")
+
+         _model = detection.__dict__[arch](assume_straight_pages=assume_straight_pages)
+     else:
+         if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
+             raise ValueError(f"unknown architecture: {type(arch)}")
+
+         _model = arch
+         _model.postprocessor.assume_straight_pages = assume_straight_pages
+
+     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
+     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
+     kwargs["batch_size"] = kwargs.get("batch_size", 4)
+     predictor = DetectionPredictor(
+         PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
+         _model,
+     )
+     return predictor
+
+
+ def detection_predictor(
+     arch: Any = "fast_base",
+     assume_straight_pages: bool = True,
+     **kwargs: Any,
+ ) -> DetectionPredictor:
+     """Text detection architecture.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import detection_predictor
+     >>> model = detection_predictor(arch='db_resnet50')
+     >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
+     >>> out = model([input_page])
+
+     Args:
+     ----
+         arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
+         assume_straight_pages: If True, fit straight boxes to the page
+         **kwargs: optional keyword arguments passed to the architecture
+
+     Returns:
+     -------
+         Detection predictor
+     """
+     return _predictor(arch, assume_straight_pages, **kwargs)
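Beyond the doctest above, `_predictor` shows that extra keyword arguments flow into the `PreProcessor` (with `mean`/`std` defaulting to the model config and `batch_size` to 4). A short sketch of overriding those knobs; the architecture name comes from `ARCHS` and, as in the doctest, the input page is random:

import numpy as np
from onnxtr.models import detection_predictor

# Rotated-text setup with a smaller preprocessing batch size
det = detection_predictor("fast_base", assume_straight_pages=False, batch_size=2)

pages = [(255 * np.random.rand(1024, 768, 3)).astype(np.uint8)]
# One prediction per page; return_maps additionally yields the raw probability maps
loc_preds, seg_maps = det(pages, return_maps=True)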
onnxtr/models/engine.py
@@ -0,0 +1,50 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, List, Union
+
+ import numpy as np
+ import onnxruntime
+
+ from onnxtr.utils.data import download_from_url
+ from onnxtr.utils.geometry import shape_translate
+
+
+ class Engine:
+     """Implements an abstract class for the engine of a model
+
+     Args:
+     ----
+         url: the url to use to download a model if needed
+         providers: list of providers to use for inference
+         **kwargs: additional arguments to be passed to `download_from_url`
+     """
+
+     def __init__(
+         self, url: str, providers: List[str] = ["CPUExecutionProvider", "CUDAExecutionProvider"], **kwargs: Any
+     ) -> None:
+         archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
+         self.runtime = onnxruntime.InferenceSession(archive_path, providers=providers)
+         self.runtime_inputs = self.runtime.get_inputs()[0]
+         self.tf_exported = int(self.runtime_inputs.shape[-1]) == 3
+         self.fixed_batch_size: Union[int, str] = self.runtime_inputs.shape[
+             0
+         ]  # mostly possible with tensorflow exported models
+         self.output_name = [output.name for output in self.runtime.get_outputs()]
+
+     def run(self, inputs: np.ndarray) -> np.ndarray:
+         if self.tf_exported:
+             inputs = shape_translate(inputs, format="BHWC")  # sanity check
+         else:
+             inputs = shape_translate(inputs, format="BCHW")
+         if isinstance(self.fixed_batch_size, int) and self.fixed_batch_size != 0:  # dynamic batch size is a string
+             inputs = np.broadcast_to(inputs, (self.fixed_batch_size, *inputs.shape))
+             # combine the results
+             logits = np.concatenate(
+                 [self.runtime.run(self.output_name, {"input": batch})[0] for batch in inputs], axis=0
+             )
+         else:
+             logits = self.runtime.run(self.output_name, {"input": inputs})[0]
+         return shape_translate(logits, format="BHWC")
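Because `Engine.__init__` only downloads when the `url` contains "http" and otherwise treats it as a local path, it can be pointed straight at an exported `.onnx` file. A minimal sketch under that assumption (the file name is hypothetical; `run` handles the layout conversion via `shape_translate`):

import numpy as np
from onnxtr.models.engine import Engine

# Hypothetical local export; an http(s) URL would be downloaded and cached instead
engine = Engine("my_detection_model.onnx", providers=["CPUExecutionProvider"])

# A channels-first float batch; `run` re-orders it to match the exported graph's layout
batch = np.random.rand(2, 3, 1024, 1024).astype(np.float32)
logits = engine.run(batch)  # output is returned channels-last (BHWC)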
onnxtr/models/predictor/__init__.py
@@ -0,0 +1 @@
+ from .predictor import *
onnxtr/models/predictor/base.py
@@ -0,0 +1,175 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+
+ import numpy as np
+
+ from onnxtr.models.builder import DocumentBuilder
+ from onnxtr.utils.geometry import extract_crops, extract_rcrops
+
+ from .._utils import rectify_crops, rectify_loc_preds
+ from ..classification import crop_orientation_predictor
+ from ..classification.predictor import OrientationPredictor
+ from ..detection.zoo import ARCHS as DETECTION_ARCHS
+ from ..recognition.zoo import ARCHS as RECOGNITION_ARCHS
+
+ __all__ = ["_OCRPredictor"]
+
+
+ class _OCRPredictor:
+     """Implements an object able to localize and identify text elements in a set of documents
+
+     Args:
+     ----
+         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
+             without rotated textual elements.
+         straighten_pages: if True, estimates the page general orientation based on the median line orientation.
+             Then, rotates the page before passing it to the deep learning modules. The final predictions will be
+             remapped accordingly. Doing so will improve performance for documents with page-uniform rotations.
+         preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
+         symmetric_pad: if True and preserve_aspect_ratio is True, pad the image symmetrically.
+         **kwargs: keyword args of `DocumentBuilder`
+     """
+
+     crop_orientation_predictor: Optional[OrientationPredictor]
+
+     def __init__(
+         self,
+         assume_straight_pages: bool = True,
+         straighten_pages: bool = False,
+         preserve_aspect_ratio: bool = True,
+         symmetric_pad: bool = True,
+         **kwargs: Any,
+     ) -> None:
+         self.assume_straight_pages = assume_straight_pages
+         self.straighten_pages = straighten_pages
+         self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor()
+         self.doc_builder = DocumentBuilder(**kwargs)
+         self.preserve_aspect_ratio = preserve_aspect_ratio
+         self.symmetric_pad = symmetric_pad
+         self.hooks: List[Callable] = []
+
+     @staticmethod
+     def _generate_crops(
+         pages: List[np.ndarray],
+         loc_preds: List[np.ndarray],
+         channels_last: bool,
+         assume_straight_pages: bool = False,
+     ) -> List[List[np.ndarray]]:
+         extraction_fn = extract_crops if assume_straight_pages else extract_rcrops
+
+         crops = [
+             extraction_fn(page, _boxes[:, :4], channels_last=channels_last)  # type: ignore[operator]
+             for page, _boxes in zip(pages, loc_preds)
+         ]
+         return crops
+
+     @staticmethod
+     def _prepare_crops(
+         pages: List[np.ndarray],
+         loc_preds: List[np.ndarray],
+         channels_last: bool,
+         assume_straight_pages: bool = False,
+     ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
+         crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)
+
+         # Avoid sending zero-sized crops
+         is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
+         crops = [
+             [crop for crop, _kept in zip(page_crops, page_kept) if _kept]
+             for page_crops, page_kept in zip(crops, is_kept)
+         ]
+         loc_preds = [_boxes[_kept] for _boxes, _kept in zip(loc_preds, is_kept)]
+
+         return crops, loc_preds
+
+     def _rectify_crops(
+         self,
+         crops: List[List[np.ndarray]],
+         loc_preds: List[np.ndarray],
+     ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]:
+         # Work at a page level
+         orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops])  # type: ignore[misc]
+         rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
+         rect_loc_preds = [
+             rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds
+             for page_loc_preds, orientation in zip(loc_preds, orientations)
+         ]
+         # Flatten to list of tuples with (value, confidence)
+         crop_orientations = [
+             (orientation, prob)
+             for page_classes, page_probs in zip(classes, probs)
+             for orientation, prob in zip(page_classes, page_probs)
+         ]
+         return rect_crops, rect_loc_preds, crop_orientations  # type: ignore[return-value]
+
+     def _remove_padding(
+         self,
+         pages: List[np.ndarray],
+         loc_preds: List[np.ndarray],
+     ) -> List[np.ndarray]:
+         if self.preserve_aspect_ratio:
+             # Rectify loc_preds to remove padding
+             rectified_preds = []
+             for page, loc_pred in zip(pages, loc_preds):
+                 h, w = page.shape[0], page.shape[1]
+                 if h > w:
+                     # y unchanged, dilate x coord
+                     if self.symmetric_pad:
+                         if self.assume_straight_pages:
+                             loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
+                         else:
+                             loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
+                     else:
+                         if self.assume_straight_pages:
+                             loc_pred[:, [0, 2]] *= h / w
+                         else:
+                             loc_pred[:, :, 0] *= h / w
+                 elif w > h:
+                     # x unchanged, dilate y coord
+                     if self.symmetric_pad:
+                         if self.assume_straight_pages:
+                             loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
+                         else:
+                             loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
+                     else:
+                         if self.assume_straight_pages:
+                             loc_pred[:, [1, 3]] *= w / h
+                         else:
+                             loc_pred[:, :, 1] *= w / h
+                 rectified_preds.append(loc_pred)
+             return rectified_preds
+         return loc_preds
+
+     @staticmethod
+     def _process_predictions(
+         loc_preds: List[np.ndarray],
+         word_preds: List[Tuple[str, float]],
+         crop_orientations: List[Dict[str, Any]],
+     ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]:
+         text_preds = []
+         crop_orientation_preds = []
+         if len(loc_preds) > 0:
+             # Text & crop orientation predictions at page level
+             _idx = 0
+             for page_boxes in loc_preds:
+                 text_preds.append(word_preds[_idx : _idx + page_boxes.shape[0]])
+                 crop_orientation_preds.append(crop_orientations[_idx : _idx + page_boxes.shape[0]])
+                 _idx += page_boxes.shape[0]
+
+         return loc_preds, text_preds, crop_orientation_preds
+
+     def add_hook(self, hook: Callable) -> None:
+         """Add a hook to the predictor
+
+         Args:
+         ----
+             hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
+         """
+         self.hooks.append(hook)
+
+     def list_archs(self) -> Dict[str, List[str]]:
+         return {"detection_archs": DETECTION_ARCHS, "recognition_archs": RECOGNITION_ARCHS}
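The coordinate handling in `_remove_padding` simply undoes aspect-ratio padding: for a portrait page (h > w) that was symmetrically padded along x to reach the square model input, relative x coordinates are stretched around the centre by h / w. A small numeric check of that mapping (the values are illustrative):

import numpy as np

h, w = 1000, 500  # portrait page, symmetrically padded along x before inference
boxes = np.array([[0.375, 0.2, 0.625, 0.4]])  # (xmin, ymin, xmax, ymax) relative to the padded input

# Same formula as the symmetric_pad / assume_straight_pages branch above
boxes[:, [0, 2]] = np.clip((boxes[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
print(boxes)  # [[0.25 0.2  0.75 0.4 ]] -> the x span doubles around the centre, y is untouched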
onnxtr/models/predictor/predictor.py
@@ -0,0 +1,145 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, List
+
+ import numpy as np
+
+ from onnxtr.io.elements import Document
+ from onnxtr.models._utils import estimate_orientation, get_language
+ from onnxtr.models.detection.predictor import DetectionPredictor
+ from onnxtr.models.recognition.predictor import RecognitionPredictor
+ from onnxtr.utils.geometry import rotate_image
+ from onnxtr.utils.repr import NestedObject
+
+ from .base import _OCRPredictor
+
+ __all__ = ["OCRPredictor"]
+
+
+ class OCRPredictor(NestedObject, _OCRPredictor):
+     """Implements an object able to localize and identify text elements in a set of documents
+
+     Args:
+     ----
+         det_predictor: detection module
+         reco_predictor: recognition module
+         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
+             without rotated textual elements.
+         straighten_pages: if True, estimates the page general orientation based on the median line orientation.
+             Then, rotates the page before passing it to the deep learning modules. The final predictions will be
+             remapped accordingly. Doing so will improve performance for documents with page-uniform rotations.
+         detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
+             page. Doing so will slightly increase the overall latency.
+         detect_language: if True, the language prediction will be added to the predictions for each
+             page. Doing so will slightly increase the overall latency.
+         **kwargs: keyword args of `DocumentBuilder`
+     """
+
+     _children_names = ["det_predictor", "reco_predictor", "doc_builder"]
+
+     def __init__(
+         self,
+         det_predictor: DetectionPredictor,
+         reco_predictor: RecognitionPredictor,
+         assume_straight_pages: bool = True,
+         straighten_pages: bool = False,
+         preserve_aspect_ratio: bool = True,
+         symmetric_pad: bool = True,
+         detect_orientation: bool = False,
+         detect_language: bool = False,
+         **kwargs: Any,
+     ) -> None:
+         self.det_predictor = det_predictor
+         self.reco_predictor = reco_predictor
+         _OCRPredictor.__init__(
+             self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
+         )
+         self.detect_orientation = detect_orientation
+         self.detect_language = detect_language
+
+     def __call__(
+         self,
+         pages: List[np.ndarray],
+         **kwargs: Any,
+     ) -> Document:
+         # Dimension check
+         if any(page.ndim != 3 for page in pages):
+             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+
+         origin_page_shapes = [page.shape[:2] for page in pages]
+
+         # Localize text elements
+         loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
+         # Detect document rotation and rotate pages
+         seg_maps = [
+             np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8)
+             for out_map in out_maps
+         ]
+         if self.detect_orientation:
+             origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
+             orientations = [
+                 {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
+             ]
+         else:
+             orientations = None
+         if self.straighten_pages:
+             origin_page_orientations = (
+                 origin_page_orientations
+                 if self.detect_orientation
+                 else [estimate_orientation(seq_map) for seq_map in seg_maps]
+             )
+             pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+             # forward again to get predictions on straight pages
+             loc_preds = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
+
+         loc_preds = [loc_pred[0] for loc_pred in loc_preds]
+
+         # Rectify crops if aspect ratio
+         loc_preds = self._remove_padding(pages, loc_preds)
+
+         # Apply hooks to loc_preds if any
+         for hook in self.hooks:
+             loc_preds = hook(loc_preds)
+
+         # Crop images
+         crops, loc_preds = self._prepare_crops(
+             pages,
+             loc_preds,  # type: ignore[arg-type]
+             channels_last=True,
+             assume_straight_pages=self.assume_straight_pages,
+         )
+         # Rectify crop orientation and get crop orientation predictions
+         crop_orientations: Any = []
+         if not self.assume_straight_pages:
+             crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
+             crop_orientations = [
+                 {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
+             ]
+
+         # Identify character sequences
+         word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
+         if not crop_orientations:
+             crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]
+
+         boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)
+
+         if self.detect_language:
+             languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]
+             languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
+         else:
+             languages_dict = None
+
+         out = self.doc_builder(
+             pages,
+             boxes,
+             text_preds,
+             origin_page_shapes,  # type: ignore[arg-type]
+             crop_orientations,
+             orientations,
+             languages_dict,
+         )
+         return out
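Tying the pieces of this diff together, an `OCRPredictor` is just a detection predictor plus a recognition predictor. A hedged end-to-end sketch; it assumes the recognition zoo exposes a `recognition_predictor` factory symmetrical to `detection_predictor` and that `crnn_vgg16_bn` is one of its architectures, neither of which is shown in this section:

import numpy as np
from onnxtr.models import detection_predictor, recognition_predictor  # recognition factory assumed
from onnxtr.models.predictor import OCRPredictor

predictor = OCRPredictor(
    det_predictor=detection_predictor("fast_base"),
    reco_predictor=recognition_predictor("crnn_vgg16_bn"),  # architecture name assumed
    assume_straight_pages=True,
    detect_orientation=False,
)

page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
document = predictor([page])  # a `Document` assembled by `DocumentBuilder`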
onnxtr/models/preprocessor/__init__.py
@@ -0,0 +1 @@
+ from .base import *