onnxtr 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/io/elements.py +17 -4
- onnxtr/io/pdf.py +6 -3
- onnxtr/models/__init__.py +1 -0
- onnxtr/models/_utils.py +57 -20
- onnxtr/models/builder.py +24 -9
- onnxtr/models/classification/models/mobilenet.py +12 -5
- onnxtr/models/classification/zoo.py +18 -6
- onnxtr/models/detection/_utils/__init__.py +1 -0
- onnxtr/models/detection/_utils/base.py +66 -0
- onnxtr/models/detection/models/differentiable_binarization.py +27 -12
- onnxtr/models/detection/models/fast.py +30 -9
- onnxtr/models/detection/models/linknet.py +24 -9
- onnxtr/models/detection/postprocessor/base.py +4 -3
- onnxtr/models/detection/predictor/base.py +15 -1
- onnxtr/models/detection/zoo.py +12 -3
- onnxtr/models/engine.py +73 -7
- onnxtr/models/predictor/base.py +65 -42
- onnxtr/models/predictor/predictor.py +22 -15
- onnxtr/models/recognition/models/crnn.py +24 -9
- onnxtr/models/recognition/models/master.py +14 -5
- onnxtr/models/recognition/models/parseq.py +14 -5
- onnxtr/models/recognition/models/sar.py +12 -5
- onnxtr/models/recognition/models/vitstr.py +18 -7
- onnxtr/models/recognition/zoo.py +9 -6
- onnxtr/models/zoo.py +16 -0
- onnxtr/py.typed +0 -0
- onnxtr/utils/geometry.py +33 -12
- onnxtr/version.py +1 -1
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/METADATA +60 -21
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/RECORD +34 -31
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/WHEEL +1 -1
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/top_level.txt +0 -1
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/LICENSE +0 -0
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/zip-safe +0 -0
onnxtr/models/predictor/base.py
CHANGED
|
@@ -8,10 +8,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
10
|
from onnxtr.models.builder import DocumentBuilder
|
|
11
|
-
from onnxtr.
|
|
11
|
+
from onnxtr.models.engine import EngineConfig
|
|
12
|
+
from onnxtr.utils.geometry import extract_crops, extract_rcrops, rotate_image
|
|
12
13
|
|
|
13
|
-
from .._utils import rectify_crops, rectify_loc_preds
|
|
14
|
-
from ..classification import crop_orientation_predictor
|
|
14
|
+
from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
|
|
15
|
+
from ..classification import crop_orientation_predictor, page_orientation_predictor
|
|
15
16
|
from ..classification.predictor import OrientationPredictor
|
|
16
17
|
from ..detection.zoo import ARCHS as DETECTION_ARCHS
|
|
17
18
|
from ..recognition.zoo import ARCHS as RECOGNITION_ARCHS
|
|
@@ -31,11 +32,15 @@ class _OCRPredictor:
|
|
|
31
32
|
accordingly. Doing so will improve performances for documents with page-uniform rotations.
|
|
32
33
|
preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
|
|
33
34
|
symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically.
|
|
35
|
+
detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
|
|
36
|
+
page. Doing so will slightly deteriorate the overall latency.
|
|
34
37
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
38
|
+
clf_engine_cfg: configuration of the orientation classification engine
|
|
35
39
|
**kwargs: keyword args of `DocumentBuilder`
|
|
36
40
|
"""
|
|
37
41
|
|
|
38
42
|
crop_orientation_predictor: Optional[OrientationPredictor]
|
|
43
|
+
page_orientation_predictor: Optional[OrientationPredictor]
|
|
39
44
|
|
|
40
45
|
def __init__(
|
|
41
46
|
self,
|
|
@@ -43,19 +48,75 @@ class _OCRPredictor:
|
|
|
43
48
|
straighten_pages: bool = False,
|
|
44
49
|
preserve_aspect_ratio: bool = True,
|
|
45
50
|
symmetric_pad: bool = True,
|
|
51
|
+
detect_orientation: bool = False,
|
|
46
52
|
load_in_8_bit: bool = False,
|
|
53
|
+
clf_engine_cfg: EngineConfig = EngineConfig(),
|
|
47
54
|
**kwargs: Any,
|
|
48
55
|
) -> None:
|
|
49
56
|
self.assume_straight_pages = assume_straight_pages
|
|
50
57
|
self.straighten_pages = straighten_pages
|
|
51
58
|
self.crop_orientation_predictor = (
|
|
52
|
-
None
|
|
59
|
+
None
|
|
60
|
+
if assume_straight_pages
|
|
61
|
+
else crop_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
|
|
62
|
+
)
|
|
63
|
+
self.page_orientation_predictor = (
|
|
64
|
+
page_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
|
|
65
|
+
if detect_orientation or straighten_pages or not assume_straight_pages
|
|
66
|
+
else None
|
|
53
67
|
)
|
|
54
68
|
self.doc_builder = DocumentBuilder(**kwargs)
|
|
55
69
|
self.preserve_aspect_ratio = preserve_aspect_ratio
|
|
56
70
|
self.symmetric_pad = symmetric_pad
|
|
57
71
|
self.hooks: List[Callable] = []
|
|
58
72
|
|
|
73
|
+
def _general_page_orientations(
|
|
74
|
+
self,
|
|
75
|
+
pages: List[np.ndarray],
|
|
76
|
+
) -> List[Tuple[int, float]]:
|
|
77
|
+
_, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
|
|
78
|
+
# Flatten to list of tuples with (value, confidence)
|
|
79
|
+
page_orientations = [
|
|
80
|
+
(orientation, prob)
|
|
81
|
+
for page_classes, page_probs in zip(classes, probs)
|
|
82
|
+
for orientation, prob in zip(page_classes, page_probs)
|
|
83
|
+
]
|
|
84
|
+
return page_orientations
|
|
85
|
+
|
|
86
|
+
def _get_orientations(
|
|
87
|
+
self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
|
|
88
|
+
) -> Tuple[List[Tuple[int, float]], List[int]]:
|
|
89
|
+
general_pages_orientations = self._general_page_orientations(pages)
|
|
90
|
+
origin_page_orientations = [
|
|
91
|
+
estimate_orientation(seq_map, general_orientation)
|
|
92
|
+
for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
|
|
93
|
+
]
|
|
94
|
+
return general_pages_orientations, origin_page_orientations
|
|
95
|
+
|
|
96
|
+
def _straighten_pages(
|
|
97
|
+
self,
|
|
98
|
+
pages: List[np.ndarray],
|
|
99
|
+
seg_maps: List[np.ndarray],
|
|
100
|
+
general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
|
|
101
|
+
origin_pages_orientations: Optional[List[int]] = None,
|
|
102
|
+
) -> List[np.ndarray]:
|
|
103
|
+
general_pages_orientations = (
|
|
104
|
+
general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
|
|
105
|
+
)
|
|
106
|
+
origin_pages_orientations = (
|
|
107
|
+
origin_pages_orientations
|
|
108
|
+
if origin_pages_orientations
|
|
109
|
+
else [
|
|
110
|
+
estimate_orientation(seq_map, general_orientation)
|
|
111
|
+
for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
|
|
112
|
+
]
|
|
113
|
+
)
|
|
114
|
+
return [
|
|
115
|
+
# We exapnd if the page is wider than tall and the angle is 90 or -90
|
|
116
|
+
rotate_image(page, angle, expand=page.shape[1] > page.shape[0] and abs(angle) == 90)
|
|
117
|
+
for page, angle in zip(pages, origin_pages_orientations)
|
|
118
|
+
]
|
|
119
|
+
|
|
59
120
|
@staticmethod
|
|
60
121
|
def _generate_crops(
|
|
61
122
|
pages: List[np.ndarray],
|
|
@@ -110,44 +171,6 @@ class _OCRPredictor:
|
|
|
110
171
|
]
|
|
111
172
|
return rect_crops, rect_loc_preds, crop_orientations # type: ignore[return-value]
|
|
112
173
|
|
|
113
|
-
def _remove_padding(
|
|
114
|
-
self,
|
|
115
|
-
pages: List[np.ndarray],
|
|
116
|
-
loc_preds: List[np.ndarray],
|
|
117
|
-
) -> List[np.ndarray]:
|
|
118
|
-
if self.preserve_aspect_ratio:
|
|
119
|
-
# Rectify loc_preds to remove padding
|
|
120
|
-
rectified_preds = []
|
|
121
|
-
for page, loc_pred in zip(pages, loc_preds):
|
|
122
|
-
h, w = page.shape[0], page.shape[1]
|
|
123
|
-
if h > w:
|
|
124
|
-
# y unchanged, dilate x coord
|
|
125
|
-
if self.symmetric_pad:
|
|
126
|
-
if self.assume_straight_pages:
|
|
127
|
-
loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
|
|
128
|
-
else:
|
|
129
|
-
loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
|
|
130
|
-
else:
|
|
131
|
-
if self.assume_straight_pages:
|
|
132
|
-
loc_pred[:, [0, 2]] *= h / w
|
|
133
|
-
else:
|
|
134
|
-
loc_pred[:, :, 0] *= h / w
|
|
135
|
-
elif w > h:
|
|
136
|
-
# x unchanged, dilate y coord
|
|
137
|
-
if self.symmetric_pad:
|
|
138
|
-
if self.assume_straight_pages:
|
|
139
|
-
loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
|
|
140
|
-
else:
|
|
141
|
-
loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
|
|
142
|
-
else:
|
|
143
|
-
if self.assume_straight_pages:
|
|
144
|
-
loc_pred[:, [1, 3]] *= w / h
|
|
145
|
-
else:
|
|
146
|
-
loc_pred[:, :, 1] *= w / h
|
|
147
|
-
rectified_preds.append(loc_pred)
|
|
148
|
-
return rectified_preds
|
|
149
|
-
return loc_preds
|
|
150
|
-
|
|
151
174
|
@staticmethod
|
|
152
175
|
def _process_predictions(
|
|
153
176
|
loc_preds: List[np.ndarray],
|
|
@@ -8,10 +8,11 @@ from typing import Any, List
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
10
|
from onnxtr.io.elements import Document
|
|
11
|
-
from onnxtr.models._utils import
|
|
11
|
+
from onnxtr.models._utils import get_language
|
|
12
12
|
from onnxtr.models.detection.predictor import DetectionPredictor
|
|
13
|
+
from onnxtr.models.engine import EngineConfig
|
|
13
14
|
from onnxtr.models.recognition.predictor import RecognitionPredictor
|
|
14
|
-
from onnxtr.utils.geometry import
|
|
15
|
+
from onnxtr.utils.geometry import detach_scores
|
|
15
16
|
from onnxtr.utils.repr import NestedObject
|
|
16
17
|
|
|
17
18
|
from .base import _OCRPredictor
|
|
@@ -35,6 +36,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
35
36
|
page. Doing so will slightly deteriorate the overall latency.
|
|
36
37
|
detect_language: if True, the language prediction will be added to the predictions for each
|
|
37
38
|
page. Doing so will slightly deteriorate the overall latency.
|
|
39
|
+
clf_engine_cfg: configuration of the orientation classification engine
|
|
38
40
|
**kwargs: keyword args of `DocumentBuilder`
|
|
39
41
|
"""
|
|
40
42
|
|
|
@@ -50,12 +52,20 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
50
52
|
symmetric_pad: bool = True,
|
|
51
53
|
detect_orientation: bool = False,
|
|
52
54
|
detect_language: bool = False,
|
|
55
|
+
clf_engine_cfg: EngineConfig = EngineConfig(),
|
|
53
56
|
**kwargs: Any,
|
|
54
57
|
) -> None:
|
|
55
58
|
self.det_predictor = det_predictor
|
|
56
59
|
self.reco_predictor = reco_predictor
|
|
57
60
|
_OCRPredictor.__init__(
|
|
58
|
-
self,
|
|
61
|
+
self,
|
|
62
|
+
assume_straight_pages,
|
|
63
|
+
straighten_pages,
|
|
64
|
+
preserve_aspect_ratio,
|
|
65
|
+
symmetric_pad,
|
|
66
|
+
detect_orientation,
|
|
67
|
+
clf_engine_cfg=clf_engine_cfg,
|
|
68
|
+
**kwargs,
|
|
59
69
|
)
|
|
60
70
|
self.detect_orientation = detect_orientation
|
|
61
71
|
self.detect_language = detect_language
|
|
@@ -80,26 +90,22 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
80
90
|
for out_map in out_maps
|
|
81
91
|
]
|
|
82
92
|
if self.detect_orientation:
|
|
83
|
-
|
|
93
|
+
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
84
94
|
orientations = [
|
|
85
|
-
{"value": orientation_page, "confidence": None} for orientation_page in
|
|
95
|
+
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
|
|
86
96
|
]
|
|
87
97
|
else:
|
|
88
98
|
orientations = None
|
|
99
|
+
general_pages_orientations = None
|
|
100
|
+
origin_pages_orientations = None
|
|
89
101
|
if self.straighten_pages:
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
if self.detect_orientation
|
|
93
|
-
else [estimate_orientation(seq_map) for seq_map in seg_maps]
|
|
94
|
-
)
|
|
95
|
-
pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
|
|
102
|
+
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
103
|
+
|
|
96
104
|
# forward again to get predictions on straight pages
|
|
97
105
|
loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
|
|
98
106
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
# Rectify crops if aspect ratio
|
|
102
|
-
loc_preds = self._remove_padding(pages, loc_preds)
|
|
107
|
+
# Detach objectness scores from loc_preds
|
|
108
|
+
loc_preds, objectness_scores = detach_scores(loc_preds) # type: ignore[arg-type]
|
|
103
109
|
|
|
104
110
|
# Apply hooks to loc_preds if any
|
|
105
111
|
for hook in self.hooks:
|
|
@@ -136,6 +142,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
136
142
|
out = self.doc_builder(
|
|
137
143
|
pages,
|
|
138
144
|
boxes,
|
|
145
|
+
objectness_scores,
|
|
139
146
|
text_preds,
|
|
140
147
|
origin_page_shapes, # type: ignore[arg-type]
|
|
141
148
|
crop_orientations,
|
|
@@ -12,7 +12,7 @@ from scipy.special import softmax
|
|
|
12
12
|
|
|
13
13
|
from onnxtr.utils import VOCABS
|
|
14
14
|
|
|
15
|
-
from ...engine import Engine
|
|
15
|
+
from ...engine import Engine, EngineConfig
|
|
16
16
|
from ..core import RecognitionPostProcessor
|
|
17
17
|
|
|
18
18
|
__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
|
|
@@ -113,6 +113,7 @@ class CRNN(Engine):
|
|
|
113
113
|
----
|
|
114
114
|
model_path: path or url to onnx model file
|
|
115
115
|
vocab: vocabulary used for encoding
|
|
116
|
+
engine_cfg: configuration for the inference engine
|
|
116
117
|
cfg: configuration dictionary
|
|
117
118
|
**kwargs: additional arguments to be passed to `Engine`
|
|
118
119
|
"""
|
|
@@ -123,10 +124,11 @@ class CRNN(Engine):
|
|
|
123
124
|
self,
|
|
124
125
|
model_path: str,
|
|
125
126
|
vocab: str,
|
|
127
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
126
128
|
cfg: Optional[Dict[str, Any]] = None,
|
|
127
129
|
**kwargs: Any,
|
|
128
130
|
) -> None:
|
|
129
|
-
super().__init__(url=model_path, **kwargs)
|
|
131
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
130
132
|
self.vocab = vocab
|
|
131
133
|
self.cfg = cfg
|
|
132
134
|
self.postprocessor = CRNNPostProcessor(self.vocab)
|
|
@@ -152,6 +154,7 @@ def _crnn(
|
|
|
152
154
|
arch: str,
|
|
153
155
|
model_path: str,
|
|
154
156
|
load_in_8_bit: bool = False,
|
|
157
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
155
158
|
**kwargs: Any,
|
|
156
159
|
) -> CRNN:
|
|
157
160
|
kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
|
|
@@ -163,11 +166,14 @@ def _crnn(
|
|
|
163
166
|
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
|
|
164
167
|
|
|
165
168
|
# Build the model
|
|
166
|
-
return CRNN(model_path, cfg=_cfg, **kwargs)
|
|
169
|
+
return CRNN(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
|
|
167
170
|
|
|
168
171
|
|
|
169
172
|
def crnn_vgg16_bn(
|
|
170
|
-
model_path: str = default_cfgs["crnn_vgg16_bn"]["url"],
|
|
173
|
+
model_path: str = default_cfgs["crnn_vgg16_bn"]["url"],
|
|
174
|
+
load_in_8_bit: bool = False,
|
|
175
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
176
|
+
**kwargs: Any,
|
|
171
177
|
) -> CRNN:
|
|
172
178
|
"""CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
|
|
173
179
|
Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
|
|
@@ -182,17 +188,21 @@ def crnn_vgg16_bn(
|
|
|
182
188
|
----
|
|
183
189
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
184
190
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
191
|
+
engine_cfg: configuration for the inference engine
|
|
185
192
|
**kwargs: keyword arguments of the CRNN architecture
|
|
186
193
|
|
|
187
194
|
Returns:
|
|
188
195
|
-------
|
|
189
196
|
text recognition architecture
|
|
190
197
|
"""
|
|
191
|
-
return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, **kwargs)
|
|
198
|
+
return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
192
199
|
|
|
193
200
|
|
|
194
201
|
def crnn_mobilenet_v3_small(
|
|
195
|
-
model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"],
|
|
202
|
+
model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"],
|
|
203
|
+
load_in_8_bit: bool = False,
|
|
204
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
205
|
+
**kwargs: Any,
|
|
196
206
|
) -> CRNN:
|
|
197
207
|
"""CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
|
|
198
208
|
Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
|
|
@@ -207,17 +217,21 @@ def crnn_mobilenet_v3_small(
|
|
|
207
217
|
----
|
|
208
218
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
209
219
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
220
|
+
engine_cfg: configuration for the inference engine
|
|
210
221
|
**kwargs: keyword arguments of the CRNN architecture
|
|
211
222
|
|
|
212
223
|
Returns:
|
|
213
224
|
-------
|
|
214
225
|
text recognition architecture
|
|
215
226
|
"""
|
|
216
|
-
return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, **kwargs)
|
|
227
|
+
return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
217
228
|
|
|
218
229
|
|
|
219
230
|
def crnn_mobilenet_v3_large(
|
|
220
|
-
model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"],
|
|
231
|
+
model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"],
|
|
232
|
+
load_in_8_bit: bool = False,
|
|
233
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
234
|
+
**kwargs: Any,
|
|
221
235
|
) -> CRNN:
|
|
222
236
|
"""CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
|
|
223
237
|
Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
|
|
@@ -232,10 +246,11 @@ def crnn_mobilenet_v3_large(
|
|
|
232
246
|
----
|
|
233
247
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
234
248
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
249
|
+
engine_cfg: configuration for the inference engine
|
|
235
250
|
**kwargs: keyword arguments of the CRNN architecture
|
|
236
251
|
|
|
237
252
|
Returns:
|
|
238
253
|
-------
|
|
239
254
|
text recognition architecture
|
|
240
255
|
"""
|
|
241
|
-
return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs)
|
|
256
|
+
return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -11,7 +11,7 @@ from scipy.special import softmax
|
|
|
11
11
|
|
|
12
12
|
from onnxtr.utils import VOCABS
|
|
13
13
|
|
|
14
|
-
from ...engine import Engine
|
|
14
|
+
from ...engine import Engine, EngineConfig
|
|
15
15
|
from ..core import RecognitionPostProcessor
|
|
16
16
|
|
|
17
17
|
__all__ = ["MASTER", "master"]
|
|
@@ -36,6 +36,7 @@ class MASTER(Engine):
|
|
|
36
36
|
----
|
|
37
37
|
model_path: path or url to onnx model file
|
|
38
38
|
vocab: vocabulary, (without EOS, SOS, PAD)
|
|
39
|
+
engine_cfg: configuration for the inference engine
|
|
39
40
|
cfg: dictionary containing information about the model
|
|
40
41
|
**kwargs: additional arguments to be passed to `Engine`
|
|
41
42
|
"""
|
|
@@ -44,10 +45,11 @@ class MASTER(Engine):
|
|
|
44
45
|
self,
|
|
45
46
|
model_path: str,
|
|
46
47
|
vocab: str,
|
|
48
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
47
49
|
cfg: Optional[Dict[str, Any]] = None,
|
|
48
50
|
**kwargs: Any,
|
|
49
51
|
) -> None:
|
|
50
|
-
super().__init__(url=model_path, **kwargs)
|
|
52
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
51
53
|
|
|
52
54
|
self.vocab = vocab
|
|
53
55
|
self.cfg = cfg
|
|
@@ -114,6 +116,7 @@ def _master(
|
|
|
114
116
|
arch: str,
|
|
115
117
|
model_path: str,
|
|
116
118
|
load_in_8_bit: bool = False,
|
|
119
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
117
120
|
**kwargs: Any,
|
|
118
121
|
) -> MASTER:
|
|
119
122
|
# Patch the config
|
|
@@ -125,10 +128,15 @@ def _master(
|
|
|
125
128
|
# Patch the url
|
|
126
129
|
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
|
|
127
130
|
|
|
128
|
-
return MASTER(model_path, cfg=_cfg, **kwargs)
|
|
131
|
+
return MASTER(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
|
|
129
132
|
|
|
130
133
|
|
|
131
|
-
def master(
|
|
134
|
+
def master(
|
|
135
|
+
model_path: str = default_cfgs["master"]["url"],
|
|
136
|
+
load_in_8_bit: bool = False,
|
|
137
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
138
|
+
**kwargs: Any,
|
|
139
|
+
) -> MASTER:
|
|
132
140
|
"""MASTER as described in paper: <https://arxiv.org/pdf/1910.02562.pdf>`_.
|
|
133
141
|
|
|
134
142
|
>>> import numpy as np
|
|
@@ -141,10 +149,11 @@ def master(model_path: str = default_cfgs["master"]["url"], load_in_8_bit: bool
|
|
|
141
149
|
----
|
|
142
150
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
143
151
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
152
|
+
engine_cfg: configuration for the inference engine
|
|
144
153
|
**kwargs: keywoard arguments passed to the MASTER architecture
|
|
145
154
|
|
|
146
155
|
Returns:
|
|
147
156
|
-------
|
|
148
157
|
text recognition architecture
|
|
149
158
|
"""
|
|
150
|
-
return _master("master", model_path, load_in_8_bit, **kwargs)
|
|
159
|
+
return _master("master", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -11,7 +11,7 @@ from scipy.special import softmax
|
|
|
11
11
|
|
|
12
12
|
from onnxtr.utils import VOCABS
|
|
13
13
|
|
|
14
|
-
from ...engine import Engine
|
|
14
|
+
from ...engine import Engine, EngineConfig
|
|
15
15
|
from ..core import RecognitionPostProcessor
|
|
16
16
|
|
|
17
17
|
__all__ = ["PARSeq", "parseq"]
|
|
@@ -35,6 +35,7 @@ class PARSeq(Engine):
|
|
|
35
35
|
----
|
|
36
36
|
model_path: path to onnx model file
|
|
37
37
|
vocab: vocabulary used for encoding
|
|
38
|
+
engine_cfg: configuration for the inference engine
|
|
38
39
|
cfg: dictionary containing information about the model
|
|
39
40
|
**kwargs: additional arguments to be passed to `Engine`
|
|
40
41
|
"""
|
|
@@ -43,10 +44,11 @@ class PARSeq(Engine):
|
|
|
43
44
|
self,
|
|
44
45
|
model_path: str,
|
|
45
46
|
vocab: str,
|
|
47
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
46
48
|
cfg: Optional[Dict[str, Any]] = None,
|
|
47
49
|
**kwargs: Any,
|
|
48
50
|
) -> None:
|
|
49
|
-
super().__init__(url=model_path, **kwargs)
|
|
51
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
50
52
|
self.vocab = vocab
|
|
51
53
|
self.cfg = cfg
|
|
52
54
|
self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
|
|
@@ -102,6 +104,7 @@ def _parseq(
|
|
|
102
104
|
arch: str,
|
|
103
105
|
model_path: str,
|
|
104
106
|
load_in_8_bit: bool = False,
|
|
107
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
105
108
|
**kwargs: Any,
|
|
106
109
|
) -> PARSeq:
|
|
107
110
|
# Patch the config
|
|
@@ -114,10 +117,15 @@ def _parseq(
|
|
|
114
117
|
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
|
|
115
118
|
|
|
116
119
|
# Build the model
|
|
117
|
-
return PARSeq(model_path, cfg=_cfg, **kwargs)
|
|
120
|
+
return PARSeq(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
|
|
118
121
|
|
|
119
122
|
|
|
120
|
-
def parseq(
|
|
123
|
+
def parseq(
|
|
124
|
+
model_path: str = default_cfgs["parseq"]["url"],
|
|
125
|
+
load_in_8_bit: bool = False,
|
|
126
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
127
|
+
**kwargs: Any,
|
|
128
|
+
) -> PARSeq:
|
|
121
129
|
"""PARSeq architecture from
|
|
122
130
|
`"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
|
|
123
131
|
|
|
@@ -131,10 +139,11 @@ def parseq(model_path: str = default_cfgs["parseq"]["url"], load_in_8_bit: bool
|
|
|
131
139
|
----
|
|
132
140
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
133
141
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
142
|
+
engine_cfg: configuration for the inference engine
|
|
134
143
|
**kwargs: keyword arguments of the PARSeq architecture
|
|
135
144
|
|
|
136
145
|
Returns:
|
|
137
146
|
-------
|
|
138
147
|
text recognition architecture
|
|
139
148
|
"""
|
|
140
|
-
return _parseq("parseq", model_path, load_in_8_bit, **kwargs)
|
|
149
|
+
return _parseq("parseq", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -11,7 +11,7 @@ from scipy.special import softmax
|
|
|
11
11
|
|
|
12
12
|
from onnxtr.utils import VOCABS
|
|
13
13
|
|
|
14
|
-
from ...engine import Engine
|
|
14
|
+
from ...engine import Engine, EngineConfig
|
|
15
15
|
from ..core import RecognitionPostProcessor
|
|
16
16
|
|
|
17
17
|
__all__ = ["SAR", "sar_resnet31"]
|
|
@@ -35,6 +35,7 @@ class SAR(Engine):
|
|
|
35
35
|
----
|
|
36
36
|
model_path: path to onnx model file
|
|
37
37
|
vocab: vocabulary used for encoding
|
|
38
|
+
engine_cfg: configuration for the inference engine
|
|
38
39
|
cfg: dictionary containing information about the model
|
|
39
40
|
**kwargs: additional arguments to be passed to `Engine`
|
|
40
41
|
"""
|
|
@@ -43,10 +44,11 @@ class SAR(Engine):
|
|
|
43
44
|
self,
|
|
44
45
|
model_path: str,
|
|
45
46
|
vocab: str,
|
|
47
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
46
48
|
cfg: Optional[Dict[str, Any]] = None,
|
|
47
49
|
**kwargs: Any,
|
|
48
50
|
) -> None:
|
|
49
|
-
super().__init__(url=model_path, **kwargs)
|
|
51
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
50
52
|
self.vocab = vocab
|
|
51
53
|
self.cfg = cfg
|
|
52
54
|
self.postprocessor = SARPostProcessor(self.vocab)
|
|
@@ -101,6 +103,7 @@ def _sar(
|
|
|
101
103
|
arch: str,
|
|
102
104
|
model_path: str,
|
|
103
105
|
load_in_8_bit: bool = False,
|
|
106
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
104
107
|
**kwargs: Any,
|
|
105
108
|
) -> SAR:
|
|
106
109
|
# Patch the config
|
|
@@ -113,11 +116,14 @@ def _sar(
|
|
|
113
116
|
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
|
|
114
117
|
|
|
115
118
|
# Build the model
|
|
116
|
-
return SAR(model_path, cfg=_cfg, **kwargs)
|
|
119
|
+
return SAR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
|
|
117
120
|
|
|
118
121
|
|
|
119
122
|
def sar_resnet31(
|
|
120
|
-
model_path: str = default_cfgs["sar_resnet31"]["url"],
|
|
123
|
+
model_path: str = default_cfgs["sar_resnet31"]["url"],
|
|
124
|
+
load_in_8_bit: bool = False,
|
|
125
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
126
|
+
**kwargs: Any,
|
|
121
127
|
) -> SAR:
|
|
122
128
|
"""SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong
|
|
123
129
|
Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
|
|
@@ -132,10 +138,11 @@ def sar_resnet31(
|
|
|
132
138
|
----
|
|
133
139
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
134
140
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
141
|
+
engine_cfg: configuration for the inference engine
|
|
135
142
|
**kwargs: keyword arguments of the SAR architecture
|
|
136
143
|
|
|
137
144
|
Returns:
|
|
138
145
|
-------
|
|
139
146
|
text recognition architecture
|
|
140
147
|
"""
|
|
141
|
-
return _sar("sar_resnet31", model_path, load_in_8_bit, **kwargs)
|
|
148
|
+
return _sar("sar_resnet31", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -11,7 +11,7 @@ from scipy.special import softmax
|
|
|
11
11
|
|
|
12
12
|
from onnxtr.utils import VOCABS
|
|
13
13
|
|
|
14
|
-
from ...engine import Engine
|
|
14
|
+
from ...engine import Engine, EngineConfig
|
|
15
15
|
from ..core import RecognitionPostProcessor
|
|
16
16
|
|
|
17
17
|
__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
|
|
@@ -43,6 +43,7 @@ class ViTSTR(Engine):
|
|
|
43
43
|
----
|
|
44
44
|
model_path: path to onnx model file
|
|
45
45
|
vocab: vocabulary used for encoding
|
|
46
|
+
engine_cfg: configuration for the inference engine
|
|
46
47
|
cfg: dictionary containing information about the model
|
|
47
48
|
**kwargs: additional arguments to be passed to `Engine`
|
|
48
49
|
"""
|
|
@@ -51,10 +52,11 @@ class ViTSTR(Engine):
|
|
|
51
52
|
self,
|
|
52
53
|
model_path: str,
|
|
53
54
|
vocab: str,
|
|
55
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
54
56
|
cfg: Optional[Dict[str, Any]] = None,
|
|
55
57
|
**kwargs: Any,
|
|
56
58
|
) -> None:
|
|
57
|
-
super().__init__(url=model_path, **kwargs)
|
|
59
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
58
60
|
self.vocab = vocab
|
|
59
61
|
self.cfg = cfg
|
|
60
62
|
|
|
@@ -112,6 +114,7 @@ def _vitstr(
|
|
|
112
114
|
arch: str,
|
|
113
115
|
model_path: str,
|
|
114
116
|
load_in_8_bit: bool = False,
|
|
117
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
115
118
|
**kwargs: Any,
|
|
116
119
|
) -> ViTSTR:
|
|
117
120
|
# Patch the config
|
|
@@ -124,11 +127,14 @@ def _vitstr(
|
|
|
124
127
|
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
|
|
125
128
|
|
|
126
129
|
# Build the model
|
|
127
|
-
return ViTSTR(model_path, cfg=_cfg, **kwargs)
|
|
130
|
+
return ViTSTR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
|
|
128
131
|
|
|
129
132
|
|
|
130
133
|
def vitstr_small(
|
|
131
|
-
model_path: str = default_cfgs["vitstr_small"]["url"],
|
|
134
|
+
model_path: str = default_cfgs["vitstr_small"]["url"],
|
|
135
|
+
load_in_8_bit: bool = False,
|
|
136
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
137
|
+
**kwargs: Any,
|
|
132
138
|
) -> ViTSTR:
|
|
133
139
|
"""ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
|
|
134
140
|
<https://arxiv.org/pdf/2105.08582.pdf>`_.
|
|
@@ -143,17 +149,21 @@ def vitstr_small(
|
|
|
143
149
|
----
|
|
144
150
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
145
151
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
152
|
+
engine_cfg: configuration for the inference engine
|
|
146
153
|
**kwargs: keyword arguments of the ViTSTR architecture
|
|
147
154
|
|
|
148
155
|
Returns:
|
|
149
156
|
-------
|
|
150
157
|
text recognition architecture
|
|
151
158
|
"""
|
|
152
|
-
return _vitstr("vitstr_small", model_path, load_in_8_bit, **kwargs)
|
|
159
|
+
return _vitstr("vitstr_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
153
160
|
|
|
154
161
|
|
|
155
162
|
def vitstr_base(
|
|
156
|
-
model_path: str = default_cfgs["vitstr_base"]["url"],
|
|
163
|
+
model_path: str = default_cfgs["vitstr_base"]["url"],
|
|
164
|
+
load_in_8_bit: bool = False,
|
|
165
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
166
|
+
**kwargs: Any,
|
|
157
167
|
) -> ViTSTR:
|
|
158
168
|
"""ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
|
|
159
169
|
<https://arxiv.org/pdf/2105.08582.pdf>`_.
|
|
@@ -168,10 +178,11 @@ def vitstr_base(
|
|
|
168
178
|
----
|
|
169
179
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
170
180
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
181
|
+
engine_cfg: configuration for the inference engine
|
|
171
182
|
**kwargs: keyword arguments of the ViTSTR architecture
|
|
172
183
|
|
|
173
184
|
Returns:
|
|
174
185
|
-------
|
|
175
186
|
text recognition architecture
|
|
176
187
|
"""
|
|
177
|
-
return _vitstr("vitstr_base", model_path, load_in_8_bit, **kwargs)
|
|
188
|
+
return _vitstr("vitstr_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
|