onnxtr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/__init__.py +2 -0
- onnxtr/contrib/__init__.py +0 -0
- onnxtr/contrib/artefacts.py +131 -0
- onnxtr/contrib/base.py +105 -0
- onnxtr/file_utils.py +33 -0
- onnxtr/io/__init__.py +5 -0
- onnxtr/io/elements.py +455 -0
- onnxtr/io/html.py +28 -0
- onnxtr/io/image.py +56 -0
- onnxtr/io/pdf.py +42 -0
- onnxtr/io/reader.py +85 -0
- onnxtr/models/__init__.py +4 -0
- onnxtr/models/_utils.py +141 -0
- onnxtr/models/builder.py +355 -0
- onnxtr/models/classification/__init__.py +2 -0
- onnxtr/models/classification/models/__init__.py +1 -0
- onnxtr/models/classification/models/mobilenet.py +120 -0
- onnxtr/models/classification/predictor/__init__.py +1 -0
- onnxtr/models/classification/predictor/base.py +57 -0
- onnxtr/models/classification/zoo.py +76 -0
- onnxtr/models/detection/__init__.py +2 -0
- onnxtr/models/detection/core.py +101 -0
- onnxtr/models/detection/models/__init__.py +3 -0
- onnxtr/models/detection/models/differentiable_binarization.py +159 -0
- onnxtr/models/detection/models/fast.py +160 -0
- onnxtr/models/detection/models/linknet.py +160 -0
- onnxtr/models/detection/postprocessor/__init__.py +0 -0
- onnxtr/models/detection/postprocessor/base.py +144 -0
- onnxtr/models/detection/predictor/__init__.py +1 -0
- onnxtr/models/detection/predictor/base.py +54 -0
- onnxtr/models/detection/zoo.py +73 -0
- onnxtr/models/engine.py +50 -0
- onnxtr/models/predictor/__init__.py +1 -0
- onnxtr/models/predictor/base.py +175 -0
- onnxtr/models/predictor/predictor.py +145 -0
- onnxtr/models/preprocessor/__init__.py +1 -0
- onnxtr/models/preprocessor/base.py +118 -0
- onnxtr/models/recognition/__init__.py +2 -0
- onnxtr/models/recognition/core.py +28 -0
- onnxtr/models/recognition/models/__init__.py +5 -0
- onnxtr/models/recognition/models/crnn.py +226 -0
- onnxtr/models/recognition/models/master.py +145 -0
- onnxtr/models/recognition/models/parseq.py +134 -0
- onnxtr/models/recognition/models/sar.py +134 -0
- onnxtr/models/recognition/models/vitstr.py +166 -0
- onnxtr/models/recognition/predictor/__init__.py +1 -0
- onnxtr/models/recognition/predictor/_utils.py +86 -0
- onnxtr/models/recognition/predictor/base.py +79 -0
- onnxtr/models/recognition/utils.py +89 -0
- onnxtr/models/recognition/zoo.py +69 -0
- onnxtr/models/zoo.py +114 -0
- onnxtr/transforms/__init__.py +1 -0
- onnxtr/transforms/base.py +112 -0
- onnxtr/utils/__init__.py +4 -0
- onnxtr/utils/common_types.py +18 -0
- onnxtr/utils/data.py +126 -0
- onnxtr/utils/fonts.py +41 -0
- onnxtr/utils/geometry.py +498 -0
- onnxtr/utils/multithreading.py +50 -0
- onnxtr/utils/reconstitution.py +70 -0
- onnxtr/utils/repr.py +64 -0
- onnxtr/utils/visualization.py +291 -0
- onnxtr/utils/vocabs.py +71 -0
- onnxtr/version.py +1 -0
- onnxtr-0.1.0.dist-info/LICENSE +201 -0
- onnxtr-0.1.0.dist-info/METADATA +481 -0
- onnxtr-0.1.0.dist-info/RECORD +70 -0
- onnxtr-0.1.0.dist-info/WHEEL +5 -0
- onnxtr-0.1.0.dist-info/top_level.txt +2 -0
- onnxtr-0.1.0.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any, List
|
|
7
|
+
|
|
8
|
+
from .. import classification
|
|
9
|
+
from ..preprocessor import PreProcessor
|
|
10
|
+
from .predictor import OrientationPredictor
|
|
11
|
+
|
|
12
|
+
__all__ = [
    "crop_orientation_predictor",
    "page_orientation_predictor",
]

# Architecture names accepted by `_orientation_predictor`; each entry is looked up
# by name in the `classification` package to build the underlying classifier.
ORIENTATION_ARCHS: List[str] = [
    "mobilenet_v3_small_crop_orientation",
    "mobilenet_v3_small_page_orientation",
]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _orientation_predictor(arch: str, **kwargs: Any) -> OrientationPredictor:
    """Build an orientation predictor around one of the supported classifiers.

    Args:
    ----
        arch: name of the classification architecture, must be one of `ORIENTATION_ARCHS`
        **kwargs: preprocessing options forwarded to `PreProcessor`; `mean` and `std` fall back
            to the model configuration, `batch_size` to 128 for crop archs and 4 otherwise

    Returns:
    -------
        an `OrientationPredictor` combining the preprocessor and the classification model

    Raises:
    ------
        ValueError: if `arch` is not listed in `ORIENTATION_ARCHS`
    """
    if arch not in ORIENTATION_ARCHS:
        raise ValueError(f"unknown architecture '{arch}'")

    # The classifier factory is resolved by name from the classification package
    model = classification.__dict__[arch]()
    model_cfg = model.cfg

    # Caller-supplied values win; otherwise use the model's normalization stats and a
    # batch size suited to the task (many small crops vs. a few full pages).
    kwargs.setdefault("mean", model_cfg["mean"])
    kwargs.setdefault("std", model_cfg["std"])
    kwargs.setdefault("batch_size", 128 if "crop" in arch else 4)

    # Drop the leading (presumably channel) dimension to get the spatial target size
    target_size = model_cfg["input_shape"][1:]
    preprocessor = PreProcessor(target_size, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs)
    return OrientationPredictor(preprocessor, model)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def crop_orientation_predictor(
    arch: Any = "mobilenet_v3_small_crop_orientation", **kwargs: Any
) -> OrientationPredictor:
    """Crop orientation classification architecture.

    >>> import numpy as np
    >>> from onnxtr.models import crop_orientation_predictor
    >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation')
    >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8)
    >>> out = model([input_crop])

    Args:
    ----
        arch: name of the architecture to use, one of `ORIENTATION_ARCHS`
            (defaults to 'mobilenet_v3_small_crop_orientation')
        **kwargs: preprocessing keyword arguments (e.g. `mean`, `std`, `batch_size`) forwarded
            to the predictor's `PreProcessor`; for crop architectures `batch_size` defaults to 128

    Returns:
    -------
        OrientationPredictor

    Raises:
    ------
        ValueError: if `arch` is not a supported orientation architecture
    """
    return _orientation_predictor(arch=arch, **kwargs)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def page_orientation_predictor(
    arch: Any = "mobilenet_v3_small_page_orientation", **kwargs: Any
) -> OrientationPredictor:
    """Page orientation classification architecture.

    >>> import numpy as np
    >>> from onnxtr.models import page_orientation_predictor
    >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation')
    >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
    >>> out = model([input_page])

    Args:
    ----
        arch: name of the architecture to use, one of `ORIENTATION_ARCHS`
            (defaults to 'mobilenet_v3_small_page_orientation')
        **kwargs: preprocessing keyword arguments (e.g. `mean`, `std`, `batch_size`) forwarded
            to the predictor's `PreProcessor`; for non-crop architectures `batch_size` defaults to 4

    Returns:
    -------
        OrientationPredictor

    Raises:
    ------
        ValueError: if `arch` is not a supported orientation architecture
    """
    return _orientation_predictor(arch=arch, **kwargs)
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from onnxtr.utils.repr import NestedObject
|
|
12
|
+
|
|
13
|
+
__all__ = ["DetectionPostProcessor"]


class DetectionPostProcessor(NestedObject):
    """Abstract class to postprocess the raw output of the model

    Subclasses implement `bitmap_to_boxes`; `__call__` binarizes the probability map,
    cleans it with a morphological opening and delegates box extraction per class.

    Args:
    ----
        box_thresh (float): minimal objectness score to consider a box
        bin_thresh (float): threshold to apply to segmentation raw heatmap
        assume_straight_pages (bool): if True, fit straight boxes only
    """

    def __init__(self, box_thresh: float = 0.5, bin_thresh: float = 0.5, assume_straight_pages: bool = True) -> None:
        self.box_thresh = box_thresh
        self.bin_thresh = bin_thresh
        self.assume_straight_pages = assume_straight_pages
        # 3x3 structuring element used for the opening (erosion + dilation) in `__call__`
        self._opening_kernel: np.ndarray = np.ones((3, 3), dtype=np.uint8)

    def extra_repr(self) -> str:
        return f"bin_thresh={self.bin_thresh}, box_thresh={self.box_thresh}"

    @staticmethod
    def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool = True) -> float:
        """Compute the confidence score for a polygon : mean of the p values on the polygon

        Args:
        ----
            pred (np.ndarray): p map returned by the model, indexed as [row (y), column (x)]
            points: coordinates of the polygon, one (x, y) point per row
            assume_straight_pages: if True, score the axis-aligned bounding box of `points`;
                otherwise score the pixels covered by the filled polygon

        Returns:
        -------
            polygon objectness; 0.0 when the polygon covers no pixel of `pred`
        """
        h, w = pred.shape[:2]

        if assume_straight_pages:
            # Clamp the bounding box to the map so the slice is never empty
            xmin = np.clip(np.floor(points[:, 0].min()).astype(np.int32), 0, w - 1)
            xmax = np.clip(np.ceil(points[:, 0].max()).astype(np.int32), 0, w - 1)
            ymin = np.clip(np.floor(points[:, 1].min()).astype(np.int32), 0, h - 1)
            ymax = np.clip(np.ceil(points[:, 1].max()).astype(np.int32), 0, h - 1)
            return pred[ymin : ymax + 1, xmin : xmax + 1].mean()

        mask: np.ndarray = np.zeros((h, w), np.int32)
        cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)  # type: ignore[call-overload]
        # Average over every pixel inside the polygon. Counting only non-zero products
        # would skip zero-probability pixels (biasing the mean upwards) and produce a
        # 0/0 = NaN score when the polygon lies entirely outside the map.
        num_pixels = np.count_nonzero(mask)
        if num_pixels == 0:
            return 0.0
        return np.sum(pred * mask) / num_pixels

    def bitmap_to_boxes(
        self,
        pred: np.ndarray,
        bitmap: np.ndarray,
    ) -> np.ndarray:
        """Extract boxes from one probability map and its binarized counterpart.

        Must be implemented by subclasses.

        Args:
        ----
            pred: probability map for a single sample and class
            bitmap: binarized (and opened) version of `pred`

        Returns:
        -------
            array of detected boxes
        """
        raise NotImplementedError

    def __call__(
        self,
        proba_map: np.ndarray,
    ) -> List[List[np.ndarray]]:
        """Performs postprocessing for a list of model outputs

        Args:
        ----
            proba_map: probability map of shape (N, H, W, C)

        Returns:
        -------
            list of N class predictions (for each input sample), where each class predictions is a list of C tensors
            of shape (*, 5) or (*, 6)

        Raises:
        ------
            AssertionError: if `proba_map` is not 4-dimensional
        """
        if proba_map.ndim != 4:
            raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.")

        # Erosion + dilation on the binary map
        bin_map = [
            [
                cv2.morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel)
                for idx in range(proba_map.shape[-1])
            ]
            for bmap in (proba_map >= self.bin_thresh).astype(np.uint8)
        ]

        return [
            [self.bitmap_to_boxes(pmaps[..., idx], bmaps[idx]) for idx in range(proba_map.shape[-1])]
            for pmaps, bmaps in zip(proba_map, bin_map)
        ]
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any, Dict, Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.special import expit
|
|
10
|
+
|
|
11
|
+
from ...engine import Engine
|
|
12
|
+
from ..postprocessor.base import GeneralDetectionPostProcessor
|
|
13
|
+
|
|
14
|
+
__all__ = [
    "DBNet",
    "db_resnet50",
    "db_resnet34",
    "db_mobilenet_v3_large",
]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# All DBNet variants share the same input size and normalization statistics;
# only the ONNX weight file differs. Each variant gets its own config dict.
default_cfgs: Dict[str, Dict[str, Any]] = {
    arch_name: {
        "input_shape": (3, 1024, 1024),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": weights_url,
    }
    for arch_name, weights_url in (
        (
            "db_resnet50",
            "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet50-69ba0015.onnx",
        ),
        (
            "db_resnet34",
            "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet34-b4873198.onnx",
        ),
        (
            "db_mobilenet_v3_large",
            "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_mobilenet_v3_large-1866973f.onnx",
        ),
    )
}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class DBNet(Engine):
    """DBNet Onnx loader

    Args:
    ----
        model_path: path or url to onnx model file
        bin_thresh: threshold for binarization of the output feature map
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: if True, fit straight bounding boxes only
        cfg: the configuration dict of the model
        **kwargs: additional arguments to be passed to `Engine`
    """

    def __init__(
        self,
        model_path: str,
        bin_thresh: float = 0.3,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
        cfg: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(url=model_path, **kwargs)
        self.cfg = cfg
        self.assume_straight_pages = assume_straight_pages
        self.postprocessor = GeneralDetectionPostProcessor(
            assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
        )

    def __call__(
        self,
        x: np.ndarray,
        return_model_output: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the model on a batch and post-process its output into boxes.

        Args:
        ----
            x: input batch, passed unchanged to `Engine.run`
            return_model_output: if True, also return the probability map under "out_map"
            **kwargs: accepted for interface compatibility; unused here

        Returns:
        -------
            dict with the post-processed predictions under "preds" and, if requested,
            the sigmoid probability map under "out_map"
        """
        logits = self.run(x)

        out: Dict[str, Any] = {}

        # `run` is expected to return raw logits; the sigmoid maps them to probabilities
        prob_map = expit(logits)
        if return_model_output:
            out["out_map"] = prob_map

        out["preds"] = self.postprocessor(prob_map)

        return out
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _dbnet(
    arch: str,
    model_path: str,
    **kwargs: Any,
) -> DBNet:
    """Create a `DBNet` loader using the registered configuration of `arch`.

    Args:
    ----
        arch: key into `default_cfgs`
        model_path: path or url of the onnx model file
        **kwargs: forwarded to `DBNet`

    Returns:
    -------
        the DBNet text detection model
    """
    arch_cfg = default_cfgs[arch]
    return DBNet(model_path, cfg=arch_cfg, **kwargs)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def db_resnet34(model_path: str = default_cfgs["db_resnet34"]["url"], **kwargs: Any) -> DBNet:
    """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
    <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-34 backbone.

    >>> import numpy as np
    >>> from onnxtr.models import db_resnet34
    >>> model = db_resnet34()
    >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path or url of the onnx model file, defaults to the released weights
            listed in `default_cfgs`
        **kwargs: forwarded to `DBNet` (e.g. `bin_thresh`, `box_thresh`,
            `assume_straight_pages`, or extra `Engine` arguments)

    Returns:
    -------
        text detection architecture
    """
    return _dbnet(arch="db_resnet34", model_path=model_path, **kwargs)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def db_resnet50(model_path: str = default_cfgs["db_resnet50"]["url"], **kwargs: Any) -> DBNet:
    """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
    <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.

    >>> import numpy as np
    >>> from onnxtr.models import db_resnet50
    >>> model = db_resnet50()
    >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path or url of the onnx model file, defaults to the released weights
            listed in `default_cfgs`
        **kwargs: forwarded to `DBNet` (e.g. `bin_thresh`, `box_thresh`,
            `assume_straight_pages`, or extra `Engine` arguments)

    Returns:
    -------
        text detection architecture
    """
    return _dbnet(arch="db_resnet50", model_path=model_path, **kwargs)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def db_mobilenet_v3_large(model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], **kwargs: Any) -> DBNet:
    """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
    <https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.

    >>> import numpy as np
    >>> from onnxtr.models import db_mobilenet_v3_large
    >>> model = db_mobilenet_v3_large()
    >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path or url of the onnx model file, defaults to the released weights
            listed in `default_cfgs`
        **kwargs: forwarded to `DBNet` (e.g. `bin_thresh`, `box_thresh`,
            `assume_straight_pages`, or extra `Engine` arguments)

    Returns:
    -------
        text detection architecture
    """
    return _dbnet(arch="db_mobilenet_v3_large", model_path=model_path, **kwargs)
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any, Dict, Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.special import expit
|
|
10
|
+
|
|
11
|
+
from ...engine import Engine
|
|
12
|
+
from ..postprocessor.base import GeneralDetectionPostProcessor
|
|
13
|
+
|
|
14
|
+
__all__ = [
    "FAST",
    "fast_tiny",
    "fast_small",
    "fast_base",
]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# All FAST variants share the same input size and normalization statistics;
# only the ONNX weight file differs. Each variant gets its own config dict.
default_cfgs: Dict[str, Dict[str, Any]] = {
    arch_name: {
        "input_shape": (3, 1024, 1024),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": weights_url,
    }
    for arch_name, weights_url in (
        (
            "fast_tiny",
            "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/rep_fast_tiny-28867779.onnx",
        ),
        (
            "fast_small",
            "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/rep_fast_small-10428b70.onnx",
        ),
        (
            "fast_base",
            "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/rep_fast_base-1b89ebf9.onnx",
        ),
    )
}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class FAST(Engine):
    """ONNX loader for the FAST text detection model.

    Args:
    ----
        model_path: path or url to onnx model file
        bin_thresh: threshold for binarization of the output feature map
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: if True, fit straight bounding boxes only
        cfg: the configuration dict of the model
        **kwargs: additional arguments to be passed to `Engine`
    """

    def __init__(
        self,
        model_path: str,
        bin_thresh: float = 0.1,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
        cfg: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(url=model_path, **kwargs)
        self.cfg = cfg
        self.assume_straight_pages = assume_straight_pages
        # Turns the probability map produced in `__call__` into boxes
        self.postprocessor = GeneralDetectionPostProcessor(
            bin_thresh=bin_thresh,
            box_thresh=box_thresh,
            assume_straight_pages=self.assume_straight_pages,
        )

    def __call__(self, x: np.ndarray, return_model_output: bool = False, **kwargs: Any) -> Dict[str, Any]:
        """Run the model on a batch and post-process its output into boxes.

        Args:
        ----
            x: input batch, passed unchanged to `Engine.run`
            return_model_output: if True, also return the probability map under "out_map"
            **kwargs: accepted for interface compatibility; unused here

        Returns:
        -------
            dict with the post-processed predictions under "preds" and, if requested,
            the sigmoid probability map under "out_map"
        """
        # `run` is expected to return raw logits; the sigmoid maps them to probabilities
        prob_map = expit(self.run(x))
        result: Dict[str, Any] = {"out_map": prob_map} if return_model_output else {}
        result["preds"] = self.postprocessor(prob_map)
        return result
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _fast(
    arch: str,
    model_path: str,
    **kwargs: Any,
) -> FAST:
    """Create a `FAST` loader using the registered configuration of `arch`.

    Args:
    ----
        arch: key into `default_cfgs`
        model_path: path or url of the onnx model file
        **kwargs: forwarded to `FAST`

    Returns:
    -------
        the FAST text detection model
    """
    arch_cfg = default_cfgs[arch]
    return FAST(model_path, cfg=arch_cfg, **kwargs)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.

    >>> import numpy as np
    >>> from onnxtr.models import fast_tiny
    >>> model = fast_tiny()
    >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path or url of the onnx model file, defaults to the released weights
            listed in `default_cfgs`
        **kwargs: keyword arguments of the FAST architecture (e.g. `bin_thresh`,
            `box_thresh`, `assume_straight_pages`, or extra `Engine` arguments)

    Returns:
    -------
        text detection architecture
    """
    return _fast(arch="fast_tiny", model_path=model_path, **kwargs)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def fast_small(model_path: str = default_cfgs["fast_small"]["url"], **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.

    >>> import numpy as np
    >>> from onnxtr.models import fast_small
    >>> model = fast_small()
    >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path or url of the onnx model file, defaults to the released weights
            listed in `default_cfgs`
        **kwargs: keyword arguments of the FAST architecture (e.g. `bin_thresh`,
            `box_thresh`, `assume_straight_pages`, or extra `Engine` arguments)

    Returns:
    -------
        text detection architecture
    """
    return _fast(arch="fast_small", model_path=model_path, **kwargs)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def fast_base(model_path: str = default_cfgs["fast_base"]["url"], **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.

    >>> import numpy as np
    >>> from onnxtr.models import fast_base
    >>> model = fast_base()
    >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path or url of the onnx model file, defaults to the released weights
            listed in `default_cfgs`
        **kwargs: keyword arguments of the FAST architecture (e.g. `bin_thresh`,
            `box_thresh`, `assume_straight_pages`, or extra `Engine` arguments)

    Returns:
    -------
        text detection architecture
    """
    return _fast(arch="fast_base", model_path=model_path, **kwargs)
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any, Dict, Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.special import expit
|
|
10
|
+
|
|
11
|
+
from ...engine import Engine
|
|
12
|
+
from ..postprocessor.base import GeneralDetectionPostProcessor
|
|
13
|
+
|
|
14
|
+
__all__ = [
    "LinkNet",
    "linknet_resnet18",
    "linknet_resnet34",
    "linknet_resnet50",
]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# All LinkNet variants share the same input size and normalization statistics;
# only the ONNX weight file differs. Each variant gets its own config dict.
default_cfgs: Dict[str, Dict[str, Any]] = {
    arch_name: {
        "input_shape": (3, 1024, 1024),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": weights_url,
    }
    for arch_name, weights_url in (
        (
            "linknet_resnet18",
            "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet18-e0e0b9dc.onnx",
        ),
        (
            "linknet_resnet34",
            "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet34-93e39a39.onnx",
        ),
        (
            "linknet_resnet50",
            "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet50-15d8c4ec.onnx",
        ),
    )
}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class LinkNet(Engine):
    """ONNX loader for the LinkNet text detection model.

    Args:
    ----
        model_path: path or url to onnx model file
        bin_thresh: threshold for binarization of the output feature map
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: if True, fit straight bounding boxes only
        cfg: the configuration dict of the model
        **kwargs: additional arguments to be passed to `Engine`
    """

    def __init__(
        self,
        model_path: str,
        bin_thresh: float = 0.1,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
        cfg: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(url=model_path, **kwargs)
        self.cfg = cfg
        self.assume_straight_pages = assume_straight_pages
        # Turns the probability map produced in `__call__` into boxes
        self.postprocessor = GeneralDetectionPostProcessor(
            bin_thresh=bin_thresh,
            box_thresh=box_thresh,
            assume_straight_pages=self.assume_straight_pages,
        )

    def __call__(self, x: np.ndarray, return_model_output: bool = False, **kwargs: Any) -> Dict[str, Any]:
        """Run the model on a batch and post-process its output into boxes.

        Args:
        ----
            x: input batch, passed unchanged to `Engine.run`
            return_model_output: if True, also return the probability map under "out_map"
            **kwargs: accepted for interface compatibility; unused here

        Returns:
        -------
            dict with the post-processed predictions under "preds" and, if requested,
            the sigmoid probability map under "out_map"
        """
        # `run` is expected to return raw logits; the sigmoid maps them to probabilities
        prob_map = expit(self.run(x))
        result: Dict[str, Any] = {"out_map": prob_map} if return_model_output else {}
        result["preds"] = self.postprocessor(prob_map)
        return result
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _linknet(
    arch: str,
    model_path: str,
    **kwargs: Any,
) -> LinkNet:
    """Create a `LinkNet` loader using the registered configuration of `arch`.

    Args:
    ----
        arch: key into `default_cfgs`
        model_path: path or url of the onnx model file
        **kwargs: forwarded to `LinkNet`

    Returns:
    -------
        the LinkNet text detection model
    """
    arch_cfg = default_cfgs[arch]
    return LinkNet(model_path, cfg=arch_cfg, **kwargs)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def linknet_resnet18(model_path: str = default_cfgs["linknet_resnet18"]["url"], **kwargs: Any) -> LinkNet:
    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_ (the `linknet_resnet18` variant).

    >>> import numpy as np
    >>> from onnxtr.models import linknet_resnet18
    >>> model = linknet_resnet18()
    >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path or url of the onnx model file, defaults to the released weights
            listed in `default_cfgs`
        **kwargs: keyword arguments of the LinkNet architecture (e.g. `bin_thresh`,
            `box_thresh`, `assume_straight_pages`, or extra `Engine` arguments)

    Returns:
    -------
        text detection architecture
    """
    return _linknet(arch="linknet_resnet18", model_path=model_path, **kwargs)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def linknet_resnet34(model_path: str = default_cfgs["linknet_resnet34"]["url"], **kwargs: Any) -> LinkNet:
    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_ (the `linknet_resnet34` variant).

    >>> import numpy as np
    >>> from onnxtr.models import linknet_resnet34
    >>> model = linknet_resnet34()
    >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path or url of the onnx model file, defaults to the released weights
            listed in `default_cfgs`
        **kwargs: keyword arguments of the LinkNet architecture (e.g. `bin_thresh`,
            `box_thresh`, `assume_straight_pages`, or extra `Engine` arguments)

    Returns:
    -------
        text detection architecture
    """
    return _linknet(arch="linknet_resnet34", model_path=model_path, **kwargs)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def linknet_resnet50(model_path: str = default_cfgs["linknet_resnet50"]["url"], **kwargs: Any) -> LinkNet:
    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_ (the `linknet_resnet50` variant).

    >>> import numpy as np
    >>> from onnxtr.models import linknet_resnet50
    >>> model = linknet_resnet50()
    >>> input_tensor = np.random.rand(1, 3, 1024, 1024)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path or url of the onnx model file, defaults to the released weights
            listed in `default_cfgs`
        **kwargs: keyword arguments of the LinkNet architecture (e.g. `bin_thresh`,
            `box_thresh`, `assume_straight_pages`, or extra `Engine` arguments)

    Returns:
    -------
        text detection architecture
    """
    return _linknet(arch="linknet_resnet50", model_path=model_path, **kwargs)
|
|
File without changes
|