onnxtr-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. onnxtr/__init__.py +2 -0
  2. onnxtr/contrib/__init__.py +0 -0
  3. onnxtr/contrib/artefacts.py +131 -0
  4. onnxtr/contrib/base.py +105 -0
  5. onnxtr/file_utils.py +33 -0
  6. onnxtr/io/__init__.py +5 -0
  7. onnxtr/io/elements.py +455 -0
  8. onnxtr/io/html.py +28 -0
  9. onnxtr/io/image.py +56 -0
  10. onnxtr/io/pdf.py +42 -0
  11. onnxtr/io/reader.py +85 -0
  12. onnxtr/models/__init__.py +4 -0
  13. onnxtr/models/_utils.py +141 -0
  14. onnxtr/models/builder.py +355 -0
  15. onnxtr/models/classification/__init__.py +2 -0
  16. onnxtr/models/classification/models/__init__.py +1 -0
  17. onnxtr/models/classification/models/mobilenet.py +120 -0
  18. onnxtr/models/classification/predictor/__init__.py +1 -0
  19. onnxtr/models/classification/predictor/base.py +57 -0
  20. onnxtr/models/classification/zoo.py +76 -0
  21. onnxtr/models/detection/__init__.py +2 -0
  22. onnxtr/models/detection/core.py +101 -0
  23. onnxtr/models/detection/models/__init__.py +3 -0
  24. onnxtr/models/detection/models/differentiable_binarization.py +159 -0
  25. onnxtr/models/detection/models/fast.py +160 -0
  26. onnxtr/models/detection/models/linknet.py +160 -0
  27. onnxtr/models/detection/postprocessor/__init__.py +0 -0
  28. onnxtr/models/detection/postprocessor/base.py +144 -0
  29. onnxtr/models/detection/predictor/__init__.py +1 -0
  30. onnxtr/models/detection/predictor/base.py +54 -0
  31. onnxtr/models/detection/zoo.py +73 -0
  32. onnxtr/models/engine.py +50 -0
  33. onnxtr/models/predictor/__init__.py +1 -0
  34. onnxtr/models/predictor/base.py +175 -0
  35. onnxtr/models/predictor/predictor.py +145 -0
  36. onnxtr/models/preprocessor/__init__.py +1 -0
  37. onnxtr/models/preprocessor/base.py +118 -0
  38. onnxtr/models/recognition/__init__.py +2 -0
  39. onnxtr/models/recognition/core.py +28 -0
  40. onnxtr/models/recognition/models/__init__.py +5 -0
  41. onnxtr/models/recognition/models/crnn.py +226 -0
  42. onnxtr/models/recognition/models/master.py +145 -0
  43. onnxtr/models/recognition/models/parseq.py +134 -0
  44. onnxtr/models/recognition/models/sar.py +134 -0
  45. onnxtr/models/recognition/models/vitstr.py +166 -0
  46. onnxtr/models/recognition/predictor/__init__.py +1 -0
  47. onnxtr/models/recognition/predictor/_utils.py +86 -0
  48. onnxtr/models/recognition/predictor/base.py +79 -0
  49. onnxtr/models/recognition/utils.py +89 -0
  50. onnxtr/models/recognition/zoo.py +69 -0
  51. onnxtr/models/zoo.py +114 -0
  52. onnxtr/transforms/__init__.py +1 -0
  53. onnxtr/transforms/base.py +112 -0
  54. onnxtr/utils/__init__.py +4 -0
  55. onnxtr/utils/common_types.py +18 -0
  56. onnxtr/utils/data.py +126 -0
  57. onnxtr/utils/fonts.py +41 -0
  58. onnxtr/utils/geometry.py +498 -0
  59. onnxtr/utils/multithreading.py +50 -0
  60. onnxtr/utils/reconstitution.py +70 -0
  61. onnxtr/utils/repr.py +64 -0
  62. onnxtr/utils/visualization.py +291 -0
  63. onnxtr/utils/vocabs.py +71 -0
  64. onnxtr/version.py +1 -0
  65. onnxtr-0.1.0.dist-info/LICENSE +201 -0
  66. onnxtr-0.1.0.dist-info/METADATA +481 -0
  67. onnxtr-0.1.0.dist-info/RECORD +70 -0
  68. onnxtr-0.1.0.dist-info/WHEEL +5 -0
  69. onnxtr-0.1.0.dist-info/top_level.txt +2 -0
  70. onnxtr-0.1.0.dist-info/zip-safe +1 -0
onnxtr/models/classification/zoo.py
@@ -0,0 +1,76 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, List
+
+ from .. import classification
+ from ..preprocessor import PreProcessor
+ from .predictor import OrientationPredictor
+
+ __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
+
+ ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
+
+
+ def _orientation_predictor(arch: str, **kwargs: Any) -> OrientationPredictor:
+     if arch not in ORIENTATION_ARCHS:
+         raise ValueError(f"unknown architecture '{arch}'")
+
+     # Load the classifier directly from the backbone
+     _model = classification.__dict__[arch]()
+     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
+     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
+     kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
+     input_shape = _model.cfg["input_shape"][1:]
+     predictor = OrientationPredictor(
+         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
+     )
+     return predictor
+
+
+ def crop_orientation_predictor(
+     arch: Any = "mobilenet_v3_small_crop_orientation", **kwargs: Any
+ ) -> OrientationPredictor:
+     """Crop orientation classification architecture.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import crop_orientation_predictor
+     >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation')
+     >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8)
+     >>> out = model([input_crop])
+
+     Args:
+     ----
+         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
+         **kwargs: keyword arguments to be passed to the OrientationPredictor
+
+     Returns:
+     -------
+         OrientationPredictor
+     """
+     return _orientation_predictor(arch, **kwargs)
+
+
+ def page_orientation_predictor(
+     arch: Any = "mobilenet_v3_small_page_orientation", **kwargs: Any
+ ) -> OrientationPredictor:
+     """Page orientation classification architecture.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import page_orientation_predictor
+     >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation')
+     >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
+     >>> out = model([input_page])
+
+     Args:
+     ----
+         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
+         **kwargs: keyword arguments to be passed to the OrientationPredictor
+
+     Returns:
+     -------
+         OrientationPredictor
+     """
+     return _orientation_predictor(arch, **kwargs)
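Both factories above delegate to `_orientation_predictor`, which wires a `PreProcessor` (aspect-ratio preserving, symmetric padding) in front of the classifier and picks a default batch size of 128 for crop models and 4 for page models. A minimal usage sketch, relying only on the public API shown in the docstrings above (the exact output format of `OrientationPredictor` is not part of this diff):

import numpy as np
from onnxtr.models import crop_orientation_predictor, page_orientation_predictor

# kwargs are forwarded to the PreProcessor, so the batch_size default
# (128 for crop models, 4 for page models) can be overridden here.
crop_model = crop_orientation_predictor(batch_size=32)
page_model = page_orientation_predictor()

crops = [(255 * np.random.rand(64, 128, 3)).astype(np.uint8) for _ in range(4)]
pages = [(255 * np.random.rand(512, 512, 3)).astype(np.uint8)]

crop_out = crop_model(crops)  # predictions for the input crops
page_out = page_model(pages)  # predictions for the input page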
onnxtr/models/detection/__init__.py
@@ -0,0 +1,2 @@
+ from .models import *
+ from .zoo import *
onnxtr/models/detection/core.py
@@ -0,0 +1,101 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import List
+
+ import cv2
+ import numpy as np
+
+ from onnxtr.utils.repr import NestedObject
+
+ __all__ = ["DetectionPostProcessor"]
+
+
+ class DetectionPostProcessor(NestedObject):
+     """Abstract class to postprocess the raw output of the model
+
+     Args:
+     ----
+         box_thresh (float): minimal objectness score to consider a box
+         bin_thresh (float): threshold to apply to the raw segmentation heatmap
+         assume_straight_pages (bool): if True, fit straight boxes only
+     """
+
+     def __init__(self, box_thresh: float = 0.5, bin_thresh: float = 0.5, assume_straight_pages: bool = True) -> None:
+         self.box_thresh = box_thresh
+         self.bin_thresh = bin_thresh
+         self.assume_straight_pages = assume_straight_pages
+         self._opening_kernel: np.ndarray = np.ones((3, 3), dtype=np.uint8)
+
+     def extra_repr(self) -> str:
+         return f"bin_thresh={self.bin_thresh}, box_thresh={self.box_thresh}"
+
+     @staticmethod
+     def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool = True) -> float:
+         """Compute the confidence score for a polygon: mean of the p values on the polygon
+
+         Args:
+         ----
+             pred (np.ndarray): p map returned by the model
+             points: coordinates of the polygon
+             assume_straight_pages: if True, fit straight boxes only
+
+         Returns:
+         -------
+             polygon objectness
+         """
+         h, w = pred.shape[:2]
+
+         if assume_straight_pages:
+             xmin = np.clip(np.floor(points[:, 0].min()).astype(np.int32), 0, w - 1)
+             xmax = np.clip(np.ceil(points[:, 0].max()).astype(np.int32), 0, w - 1)
+             ymin = np.clip(np.floor(points[:, 1].min()).astype(np.int32), 0, h - 1)
+             ymax = np.clip(np.ceil(points[:, 1].max()).astype(np.int32), 0, h - 1)
+             return pred[ymin : ymax + 1, xmin : xmax + 1].mean()
+
+         else:
+             mask: np.ndarray = np.zeros((h, w), np.int32)
+             cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)  # type: ignore[call-overload]
+             product = pred * mask
+             return np.sum(product) / np.count_nonzero(product)
+
+     def bitmap_to_boxes(
+         self,
+         pred: np.ndarray,
+         bitmap: np.ndarray,
+     ) -> np.ndarray:
+         raise NotImplementedError
+
+     def __call__(
+         self,
+         proba_map,
+     ) -> List[List[np.ndarray]]:
+         """Performs postprocessing for a list of model outputs
+
+         Args:
+         ----
+             proba_map: probability map of shape (N, H, W, C)
+
+         Returns:
+         -------
+             list of N class predictions (for each input sample), where each class prediction is a list of C tensors
+             of shape (*, 5) or (*, 6)
+         """
+         if proba_map.ndim != 4:
+             raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.")
+
+         # Erosion + dilation on the binary map
+         bin_map = [
+             [
+                 cv2.morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel)
+                 for idx in range(proba_map.shape[-1])
+             ]
+             for bmap in (proba_map >= self.bin_thresh).astype(np.uint8)
+         ]
+
+         return [
+             [self.bitmap_to_boxes(pmaps[..., idx], bmaps[idx]) for idx in range(proba_map.shape[-1])]
+             for pmaps, bmaps in zip(proba_map, bin_map)
+         ]
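`box_score` averages the probability map either over the axis-aligned bounding rectangle of the points (straight pages) or over a filled polygon mask (rotated pages). A minimal sketch on a synthetic map, assuming the class is importable from onnxtr.models.detection.core (the detection/core.py file listed above); the numbers are made up for illustration:

import numpy as np
from onnxtr.models.detection.core import DetectionPostProcessor  # assumed module path

# Synthetic 8x8 probability map with a "hot" rectangular region of score 0.9
pred = np.zeros((8, 8), dtype=np.float32)
pred[2:5, 2:6] = 0.9

# Polygon covering that region, given as (x, y) corner points
points = np.array([[2, 2], [5, 2], [5, 4], [2, 4]], dtype=np.float32)

# Straight mode: mean over the clipped bounding rectangle
straight_score = DetectionPostProcessor.box_score(pred, points, assume_straight_pages=True)
# Rotated mode: mean over the filled polygon mask
rotated_score = DetectionPostProcessor.box_score(pred, points, assume_straight_pages=False)
print(straight_score, rotated_score)  # both close to 0.9, since the polygon sits on the hot region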
onnxtr/models/detection/models/__init__.py
@@ -0,0 +1,3 @@
+ from .fast import *
+ from .differentiable_binarization import *
+ from .linknet import *
onnxtr/models/detection/models/differentiable_binarization.py
@@ -0,0 +1,159 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ from scipy.special import expit
+
+ from ...engine import Engine
+ from ..postprocessor.base import GeneralDetectionPostProcessor
+
+ __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
+
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "db_resnet50": {
+         "input_shape": (3, 1024, 1024),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet50-69ba0015.onnx",
+     },
+     "db_resnet34": {
+         "input_shape": (3, 1024, 1024),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet34-b4873198.onnx",
+     },
+     "db_mobilenet_v3_large": {
+         "input_shape": (3, 1024, 1024),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_mobilenet_v3_large-1866973f.onnx",
+     },
+ }
+
+
+ class DBNet(Engine):
+     """DBNet Onnx loader
+
+     Args:
+     ----
+         model_path: path or url to onnx model file
+         bin_thresh: threshold for binarization of the output feature map
+         box_thresh: minimal objectness score to consider a box
+         assume_straight_pages: if True, fit straight bounding boxes only
+         cfg: the configuration dict of the model
+         **kwargs: additional arguments to be passed to `Engine`
+     """
+
+     def __init__(
+         self,
+         model_path: str,
+         bin_thresh: float = 0.3,
+         box_thresh: float = 0.1,
+         assume_straight_pages: bool = True,
+         cfg: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(url=model_path, **kwargs)
+         self.cfg = cfg
+         self.assume_straight_pages = assume_straight_pages
+         self.postprocessor = GeneralDetectionPostProcessor(
+             assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+         )
+
+     def __call__(
+         self,
+         x: np.ndarray,
+         return_model_output: bool = False,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         logits = self.run(x)
+
+         out: Dict[str, Any] = {}
+
+         prob_map = expit(logits)
+         if return_model_output:
+             out["out_map"] = prob_map
+
+         out["preds"] = self.postprocessor(prob_map)
+
+         return out
+
+
+ def _dbnet(
+     arch: str,
+     model_path: str,
+     **kwargs: Any,
+ ) -> DBNet:
+     # Build the model
+     return DBNet(model_path, cfg=default_cfgs[arch], **kwargs)
+
+
+ def db_resnet34(model_path: str = default_cfgs["db_resnet34"]["url"], **kwargs: Any) -> DBNet:
+     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+     <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-34 backbone.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import db_resnet34
+     >>> model = db_resnet34()
+     >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the DBNet architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _dbnet("db_resnet34", model_path, **kwargs)
+
+
+ def db_resnet50(model_path: str = default_cfgs["db_resnet50"]["url"], **kwargs: Any) -> DBNet:
+     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+     <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import db_resnet50
+     >>> model = db_resnet50()
+     >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the DBNet architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _dbnet("db_resnet50", model_path, **kwargs)
+
+
+ def db_mobilenet_v3_large(model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], **kwargs: Any) -> DBNet:
+     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+     <https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import db_mobilenet_v3_large
+     >>> model = db_mobilenet_v3_large()
+     >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the DBNet architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _dbnet("db_mobilenet_v3_large", model_path, **kwargs)
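Because `_dbnet` forwards `**kwargs` to `DBNet.__init__`, the binarization and box thresholds documented above can be tuned directly through the factory functions. A minimal sketch mirroring the docstring examples; the threshold values here are arbitrary:

import numpy as np
from onnxtr.models import db_resnet50

# bin_thresh / box_thresh are passed through to the GeneralDetectionPostProcessor
model = db_resnet50(bin_thresh=0.5, box_thresh=0.2)

input_tensor = np.random.rand(1, 3, 1024, 1024)
out = model(input_tensor, return_model_output=True)
# out["preds"]: postprocessed detections, out["out_map"]: sigmoid probability map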
onnxtr/models/detection/models/fast.py
@@ -0,0 +1,160 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ from scipy.special import expit
+
+ from ...engine import Engine
+ from ..postprocessor.base import GeneralDetectionPostProcessor
+
+ __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"]
+
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "fast_tiny": {
+         "input_shape": (3, 1024, 1024),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/rep_fast_tiny-28867779.onnx",
+     },
+     "fast_small": {
+         "input_shape": (3, 1024, 1024),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/rep_fast_small-10428b70.onnx",
+     },
+     "fast_base": {
+         "input_shape": (3, 1024, 1024),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/rep_fast_base-1b89ebf9.onnx",
+     },
+ }
+
+
+ class FAST(Engine):
+     """FAST Onnx loader
+
+     Args:
+     ----
+         model_path: path or url to onnx model file
+         bin_thresh: threshold for binarization of the output feature map
+         box_thresh: minimal objectness score to consider a box
+         assume_straight_pages: if True, fit straight bounding boxes only
+         cfg: the configuration dict of the model
+         **kwargs: additional arguments to be passed to `Engine`
+     """
+
+     def __init__(
+         self,
+         model_path: str,
+         bin_thresh: float = 0.1,
+         box_thresh: float = 0.1,
+         assume_straight_pages: bool = True,
+         cfg: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(url=model_path, **kwargs)
+         self.cfg = cfg
+         self.assume_straight_pages = assume_straight_pages
+
+         self.postprocessor = GeneralDetectionPostProcessor(
+             assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+         )
+
+     def __call__(
+         self,
+         x: np.ndarray,
+         return_model_output: bool = False,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         logits = self.run(x)
+
+         out: Dict[str, Any] = {}
+
+         prob_map = expit(logits)
+         if return_model_output:
+             out["out_map"] = prob_map
+
+         out["preds"] = self.postprocessor(prob_map)
+
+         return out
+
+
+ def _fast(
+     arch: str,
+     model_path: str,
+     **kwargs: Any,
+ ) -> FAST:
+     # Build the model
+     return FAST(model_path, cfg=default_cfgs[arch], **kwargs)
+
+
+ def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], **kwargs: Any) -> FAST:
+     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import fast_tiny
+     >>> model = fast_tiny()
+     >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the FAST architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _fast("fast_tiny", model_path, **kwargs)
+
+
+ def fast_small(model_path: str = default_cfgs["fast_small"]["url"], **kwargs: Any) -> FAST:
+     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import fast_small
+     >>> model = fast_small()
+     >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the FAST architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _fast("fast_small", model_path, **kwargs)
+
+
+ def fast_base(model_path: str = default_cfgs["fast_base"]["url"], **kwargs: Any) -> FAST:
+     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import fast_base
+     >>> model = fast_base()
+     >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the FAST architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _fast("fast_base", model_path, **kwargs)
onnxtr/models/detection/models/linknet.py
@@ -0,0 +1,160 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ from scipy.special import expit
+
+ from ...engine import Engine
+ from ..postprocessor.base import GeneralDetectionPostProcessor
+
+ __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
+
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "linknet_resnet18": {
+         "input_shape": (3, 1024, 1024),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet18-e0e0b9dc.onnx",
+     },
+     "linknet_resnet34": {
+         "input_shape": (3, 1024, 1024),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet34-93e39a39.onnx",
+     },
+     "linknet_resnet50": {
+         "input_shape": (3, 1024, 1024),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet50-15d8c4ec.onnx",
+     },
+ }
+
+
+ class LinkNet(Engine):
+     """LinkNet Onnx loader
+
+     Args:
+     ----
+         model_path: path or url to onnx model file
+         bin_thresh: threshold for binarization of the output feature map
+         box_thresh: minimal objectness score to consider a box
+         assume_straight_pages: if True, fit straight bounding boxes only
+         cfg: the configuration dict of the model
+         **kwargs: additional arguments to be passed to `Engine`
+     """
+
+     def __init__(
+         self,
+         model_path: str,
+         bin_thresh: float = 0.1,
+         box_thresh: float = 0.1,
+         assume_straight_pages: bool = True,
+         cfg: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(url=model_path, **kwargs)
+         self.cfg = cfg
+         self.assume_straight_pages = assume_straight_pages
+
+         self.postprocessor = GeneralDetectionPostProcessor(
+             assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+         )
+
+     def __call__(
+         self,
+         x: np.ndarray,
+         return_model_output: bool = False,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         logits = self.run(x)
+
+         out: Dict[str, Any] = {}
+
+         prob_map = expit(logits)
+         if return_model_output:
+             out["out_map"] = prob_map
+
+         out["preds"] = self.postprocessor(prob_map)
+
+         return out
+
+
+ def _linknet(
+     arch: str,
+     model_path: str,
+     **kwargs: Any,
+ ) -> LinkNet:
+     # Build the model
+     return LinkNet(model_path, cfg=default_cfgs[arch], **kwargs)
+
+
+ def linknet_resnet18(model_path: str = default_cfgs["linknet_resnet18"]["url"], **kwargs: Any) -> LinkNet:
+     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+     <https://arxiv.org/pdf/1707.03718.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import linknet_resnet18
+     >>> model = linknet_resnet18()
+     >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the LinkNet architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _linknet("linknet_resnet18", model_path, **kwargs)
+
+
+ def linknet_resnet34(model_path: str = default_cfgs["linknet_resnet34"]["url"], **kwargs: Any) -> LinkNet:
+     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+     <https://arxiv.org/pdf/1707.03718.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import linknet_resnet34
+     >>> model = linknet_resnet34()
+     >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the LinkNet architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _linknet("linknet_resnet34", model_path, **kwargs)
+
+
+ def linknet_resnet50(model_path: str = default_cfgs["linknet_resnet50"]["url"], **kwargs: Any) -> LinkNet:
+     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+     <https://arxiv.org/pdf/1707.03718.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import linknet_resnet50
+     >>> model = linknet_resnet50()
+     >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the LinkNet architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _linknet("linknet_resnet50", model_path, **kwargs)
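DBNet, FAST, and LinkNet all wrap the same `Engine.run` call, apply a sigmoid via `scipy.special.expit`, and delegate to `GeneralDetectionPostProcessor`, so the three detector families are interchangeable behind one call signature. A minimal sketch, reusing only the imports shown in the docstrings above:

import numpy as np
from onnxtr.models import db_resnet50, fast_base, linknet_resnet18

input_tensor = np.random.rand(1, 3, 1024, 1024)

for build in (db_resnet50, fast_base, linknet_resnet18):
    model = build()
    preds = model(input_tensor)["preds"]  # postprocessed detections (see DetectionPostProcessor above)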