onnxtr 0.1.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/io/elements.py +17 -4
- onnxtr/io/pdf.py +6 -3
- onnxtr/models/__init__.py +1 -0
- onnxtr/models/_utils.py +57 -20
- onnxtr/models/builder.py +24 -9
- onnxtr/models/classification/models/mobilenet.py +25 -7
- onnxtr/models/classification/predictor/base.py +1 -0
- onnxtr/models/classification/zoo.py +22 -7
- onnxtr/models/detection/_utils/__init__.py +1 -0
- onnxtr/models/detection/_utils/base.py +66 -0
- onnxtr/models/detection/models/differentiable_binarization.py +41 -11
- onnxtr/models/detection/models/fast.py +37 -9
- onnxtr/models/detection/models/linknet.py +39 -9
- onnxtr/models/detection/postprocessor/base.py +4 -3
- onnxtr/models/detection/predictor/base.py +15 -1
- onnxtr/models/detection/zoo.py +16 -3
- onnxtr/models/engine.py +75 -9
- onnxtr/models/predictor/base.py +69 -42
- onnxtr/models/predictor/predictor.py +22 -15
- onnxtr/models/recognition/models/crnn.py +39 -9
- onnxtr/models/recognition/models/master.py +19 -5
- onnxtr/models/recognition/models/parseq.py +20 -5
- onnxtr/models/recognition/models/sar.py +19 -5
- onnxtr/models/recognition/models/vitstr.py +31 -9
- onnxtr/models/recognition/zoo.py +12 -6
- onnxtr/models/zoo.py +22 -0
- onnxtr/py.typed +0 -0
- onnxtr/utils/geometry.py +33 -12
- onnxtr/version.py +1 -1
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/METADATA +81 -16
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/RECORD +35 -32
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/WHEEL +1 -1
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/top_level.txt +0 -1
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/LICENSE +0 -0
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/zip-safe +0 -0
onnxtr/models/predictor/base.py
CHANGED

@@ -8,10 +8,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import numpy as np
 
 from onnxtr.models.builder import DocumentBuilder
-from onnxtr.
+from onnxtr.models.engine import EngineConfig
+from onnxtr.utils.geometry import extract_crops, extract_rcrops, rotate_image
 
-from .._utils import rectify_crops, rectify_loc_preds
-from ..classification import crop_orientation_predictor
+from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
+from ..classification import crop_orientation_predictor, page_orientation_predictor
 from ..classification.predictor import OrientationPredictor
 from ..detection.zoo import ARCHS as DETECTION_ARCHS
 from ..recognition.zoo import ARCHS as RECOGNITION_ARCHS

@@ -31,10 +32,15 @@ class _OCRPredictor:
             accordingly. Doing so will improve performances for documents with page-uniform rotations.
         preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
         symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically.
+        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
+            page. Doing so will slightly deteriorate the overall latency.
+        load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+        clf_engine_cfg: configuration of the orientation classification engine
         **kwargs: keyword args of `DocumentBuilder`
     """
 
     crop_orientation_predictor: Optional[OrientationPredictor]
+    page_orientation_predictor: Optional[OrientationPredictor]
 
     def __init__(
         self,

@@ -42,16 +48,75 @@ class _OCRPredictor:
         straighten_pages: bool = False,
         preserve_aspect_ratio: bool = True,
         symmetric_pad: bool = True,
+        detect_orientation: bool = False,
+        load_in_8_bit: bool = False,
+        clf_engine_cfg: EngineConfig = EngineConfig(),
         **kwargs: Any,
     ) -> None:
         self.assume_straight_pages = assume_straight_pages
         self.straighten_pages = straighten_pages
-        self.crop_orientation_predictor =
+        self.crop_orientation_predictor = (
+            None
+            if assume_straight_pages
+            else crop_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
+        )
+        self.page_orientation_predictor = (
+            page_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
+            if detect_orientation or straighten_pages or not assume_straight_pages
+            else None
+        )
         self.doc_builder = DocumentBuilder(**kwargs)
         self.preserve_aspect_ratio = preserve_aspect_ratio
         self.symmetric_pad = symmetric_pad
         self.hooks: List[Callable] = []
 
+    def _general_page_orientations(
+        self,
+        pages: List[np.ndarray],
+    ) -> List[Tuple[int, float]]:
+        _, classes, probs = zip(self.page_orientation_predictor(pages))  # type: ignore[misc]
+        # Flatten to list of tuples with (value, confidence)
+        page_orientations = [
+            (orientation, prob)
+            for page_classes, page_probs in zip(classes, probs)
+            for orientation, prob in zip(page_classes, page_probs)
+        ]
+        return page_orientations
+
+    def _get_orientations(
+        self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
+    ) -> Tuple[List[Tuple[int, float]], List[int]]:
+        general_pages_orientations = self._general_page_orientations(pages)
+        origin_page_orientations = [
+            estimate_orientation(seq_map, general_orientation)
+            for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
+        ]
+        return general_pages_orientations, origin_page_orientations
+
+    def _straighten_pages(
+        self,
+        pages: List[np.ndarray],
+        seg_maps: List[np.ndarray],
+        general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
+        origin_pages_orientations: Optional[List[int]] = None,
+    ) -> List[np.ndarray]:
+        general_pages_orientations = (
+            general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
+        )
+        origin_pages_orientations = (
+            origin_pages_orientations
+            if origin_pages_orientations
+            else [
+                estimate_orientation(seq_map, general_orientation)
+                for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
+            ]
+        )
+        return [
+            # We exapnd if the page is wider than tall and the angle is 90 or -90
+            rotate_image(page, angle, expand=page.shape[1] > page.shape[0] and abs(angle) == 90)
+            for page, angle in zip(pages, origin_pages_orientations)
+        ]
+
     @staticmethod
     def _generate_crops(
         pages: List[np.ndarray],

@@ -106,44 +171,6 @@ class _OCRPredictor:
         ]
         return rect_crops, rect_loc_preds, crop_orientations  # type: ignore[return-value]
 
-    def _remove_padding(
-        self,
-        pages: List[np.ndarray],
-        loc_preds: List[np.ndarray],
-    ) -> List[np.ndarray]:
-        if self.preserve_aspect_ratio:
-            # Rectify loc_preds to remove padding
-            rectified_preds = []
-            for page, loc_pred in zip(pages, loc_preds):
-                h, w = page.shape[0], page.shape[1]
-                if h > w:
-                    # y unchanged, dilate x coord
-                    if self.symmetric_pad:
-                        if self.assume_straight_pages:
-                            loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
-                        else:
-                            loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
-                    else:
-                        if self.assume_straight_pages:
-                            loc_pred[:, [0, 2]] *= h / w
-                        else:
-                            loc_pred[:, :, 0] *= h / w
-                elif w > h:
-                    # x unchanged, dilate y coord
-                    if self.symmetric_pad:
-                        if self.assume_straight_pages:
-                            loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
-                        else:
-                            loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
-                    else:
-                        if self.assume_straight_pages:
-                            loc_pred[:, [1, 3]] *= w / h
-                        else:
-                            loc_pred[:, :, 1] *= w / h
-                rectified_preds.append(loc_pred)
-            return rectified_preds
-        return loc_preds
-
     @staticmethod
     def _process_predictions(
         loc_preds: List[np.ndarray],
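The new `_OCRPredictor` helpers wire a page-level orientation classifier into the pipeline: `_general_page_orientations` flattens the batched output of `page_orientation_predictor` into one `(orientation, confidence)` tuple per page, which `_get_orientations` and `_straighten_pages` then combine with `estimate_orientation` and `rotate_image`. A minimal sketch of that flattening step, using a stubbed predictor in place of the real `OrientationPredictor` (the three-element return shape is an assumption based on the unpacking shown above):

```python
from typing import List, Tuple

# Stub standing in for page_orientation_predictor: the diff unpacks its output as
# (_, classes, probs), so we assume (raw_outputs, class_values, confidences).
def stub_page_orientation_predictor(pages: List[str]):
    return [None] * len(pages), [0, -90], [0.98, 0.87]

pages = ["page_1", "page_2"]
_, classes, probs = zip(stub_page_orientation_predictor(pages))
# Flatten to one (value, confidence) tuple per page, as in _general_page_orientations
page_orientations: List[Tuple[int, float]] = [
    (orientation, prob)
    for page_classes, page_probs in zip(classes, probs)
    for orientation, prob in zip(page_classes, page_probs)
]
print(page_orientations)  # [(0, 0.98), (-90, 0.87)]
```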

onnxtr/models/predictor/predictor.py
CHANGED

@@ -8,10 +8,11 @@ from typing import Any, List
 import numpy as np
 
 from onnxtr.io.elements import Document
-from onnxtr.models._utils import
+from onnxtr.models._utils import get_language
 from onnxtr.models.detection.predictor import DetectionPredictor
+from onnxtr.models.engine import EngineConfig
 from onnxtr.models.recognition.predictor import RecognitionPredictor
-from onnxtr.utils.geometry import
+from onnxtr.utils.geometry import detach_scores
 from onnxtr.utils.repr import NestedObject
 
 from .base import _OCRPredictor

@@ -35,6 +36,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
             page. Doing so will slightly deteriorate the overall latency.
         detect_language: if True, the language prediction will be added to the predictions for each
             page. Doing so will slightly deteriorate the overall latency.
+        clf_engine_cfg: configuration of the orientation classification engine
         **kwargs: keyword args of `DocumentBuilder`
     """
 

@@ -50,12 +52,20 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         symmetric_pad: bool = True,
         detect_orientation: bool = False,
         detect_language: bool = False,
+        clf_engine_cfg: EngineConfig = EngineConfig(),
         **kwargs: Any,
     ) -> None:
         self.det_predictor = det_predictor
         self.reco_predictor = reco_predictor
         _OCRPredictor.__init__(
-            self,
+            self,
+            assume_straight_pages,
+            straighten_pages,
+            preserve_aspect_ratio,
+            symmetric_pad,
+            detect_orientation,
+            clf_engine_cfg=clf_engine_cfg,
+            **kwargs,
         )
         self.detect_orientation = detect_orientation
         self.detect_language = detect_language

@@ -80,26 +90,22 @@ class OCRPredictor(NestedObject, _OCRPredictor):
             for out_map in out_maps
         ]
         if self.detect_orientation:
-
+            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
             orientations = [
-                {"value": orientation_page, "confidence": None} for orientation_page in
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
             ]
         else:
             orientations = None
+            general_pages_orientations = None
+            origin_pages_orientations = None
         if self.straighten_pages:
-
-
-            if self.detect_orientation
-            else [estimate_orientation(seq_map) for seq_map in seg_maps]
-            )
-            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+
         # forward again to get predictions on straight pages
         loc_preds = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
 
-
-
-        # Rectify crops if aspect ratio
-        loc_preds = self._remove_padding(pages, loc_preds)
+        # Detach objectness scores from loc_preds
+        loc_preds, objectness_scores = detach_scores(loc_preds)  # type: ignore[arg-type]
 
         # Apply hooks to loc_preds if any
         for hook in self.hooks:

@@ -136,6 +142,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         out = self.doc_builder(
             pages,
             boxes,
+            objectness_scores,
             text_preds,
             origin_page_shapes,  # type: ignore[arg-type]
             crop_orientations,
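`OCRPredictor.forward` now strips objectness scores off the localization predictions with `detach_scores` and hands them to `DocumentBuilder` as a separate argument. The actual implementation lives in `onnxtr/utils/geometry.py` (changed in this release but not shown here); purely as an illustration, a hypothetical split for straight boxes, assuming each page prediction is an `(N, 5)` array whose last column is the score:

```python
from typing import List, Tuple

import numpy as np

def split_objectness_scores(loc_preds: List[np.ndarray]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    # Hypothetical sketch (not the library's detach_scores): split (N, 5)
    # [xmin, ymin, xmax, ymax, score] arrays into (N, 4) boxes and (N,) scores.
    boxes = [preds[:, :4] for preds in loc_preds]
    scores = [preds[:, 4] for preds in loc_preds]
    return boxes, scores

page_preds = [np.array([[0.1, 0.2, 0.4, 0.3, 0.99], [0.5, 0.6, 0.8, 0.7, 0.87]])]
boxes, objectness_scores = split_objectness_scores(page_preds)
print(boxes[0].shape, objectness_scores[0])  # (2, 4) [0.99 0.87]
```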

onnxtr/models/recognition/models/crnn.py
CHANGED

@@ -12,7 +12,7 @@ from scipy.special import softmax
 
 from onnxtr.utils import VOCABS
 
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
 from ..core import RecognitionPostProcessor
 
 __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]

@@ -24,6 +24,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 32, 128),
         "vocab": VOCABS["legacy_french"],
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_vgg16_bn-662979cc.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/crnn_vgg16_bn_static_8_bit-bce050c7.onnx",
     },
     "crnn_mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),

@@ -31,6 +32,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 32, 128),
         "vocab": VOCABS["french"],
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_mobilenet_v3_small-bded4d49.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/crnn_mobilenet_v3_small_static_8_bit-4949006f.onnx",
     },
     "crnn_mobilenet_v3_large": {
         "mean": (0.694, 0.695, 0.693),

@@ -38,6 +40,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 32, 128),
         "vocab": VOCABS["french"],
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_mobilenet_v3_large-d42e8185.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/crnn_mobilenet_v3_large_static_8_bit-459e856d.onnx",
     },
 }
 

@@ -110,6 +113,7 @@ class CRNN(Engine):
     ----
         model_path: path or url to onnx model file
         vocab: vocabulary used for encoding
+        engine_cfg: configuration for the inference engine
         cfg: configuration dictionary
         **kwargs: additional arguments to be passed to `Engine`
     """

@@ -120,10 +124,11 @@ class CRNN(Engine):
         self,
         model_path: str,
         vocab: str,
+        engine_cfg: EngineConfig = EngineConfig(),
         cfg: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        super().__init__(url=model_path, **kwargs)
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
         self.vocab = vocab
         self.cfg = cfg
         self.postprocessor = CRNNPostProcessor(self.vocab)

@@ -148,6 +153,8 @@ class CRNN(Engine):
 def _crnn(
     arch: str,
     model_path: str,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
     **kwargs: Any,
 ) -> CRNN:
     kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])

@@ -155,12 +162,19 @@ def _crnn(
     _cfg = deepcopy(default_cfgs[arch])
     _cfg["vocab"] = kwargs["vocab"]
     _cfg["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+    # Patch the url
+    model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
 
     # Build the model
-    return CRNN(model_path, cfg=_cfg, **kwargs)
+    return CRNN(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
 
 
-def crnn_vgg16_bn(
+def crnn_vgg16_bn(
+    model_path: str = default_cfgs["crnn_vgg16_bn"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> CRNN:
     """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
     Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
 

@@ -173,16 +187,23 @@ def crnn_vgg16_bn(model_path: str = default_cfgs["crnn_vgg16_bn"]["url"], **kwar
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the CRNN architecture
 
     Returns:
     -------
         text recognition architecture
     """
-    return _crnn("crnn_vgg16_bn", model_path, **kwargs)
+    return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
-def crnn_mobilenet_v3_small(
+def crnn_mobilenet_v3_small(
+    model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> CRNN:
     """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
     Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
 

@@ -195,16 +216,23 @@ def crnn_mobilenet_v3_small(model_path: str = default_cfgs["crnn_mobilenet_v3_sm
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the CRNN architecture
 
     Returns:
     -------
         text recognition architecture
     """
-    return _crnn("crnn_mobilenet_v3_small", model_path, **kwargs)
+    return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
-def crnn_mobilenet_v3_large(
+def crnn_mobilenet_v3_large(
+    model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> CRNN:
     """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
     Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
 

@@ -217,10 +245,12 @@ def crnn_mobilenet_v3_large(model_path: str = default_cfgs["crnn_mobilenet_v3_la
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the CRNN architecture
 
     Returns:
     -------
         text recognition architecture
     """
-    return _crnn("crnn_mobilenet_v3_large", model_path, **kwargs)
+    return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
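Every CRNN builder now exposes `load_in_8_bit` and `engine_cfg`: when `load_in_8_bit=True` and the model path is still the default remote `url`, the builder substitutes the `url_8_bit` weights from `default_cfgs`, and `engine_cfg` is forwarded to the underlying `Engine`. A minimal usage sketch, assuming the builders are re-exported from `onnxtr.models` (doctr-style layout) and that a default `EngineConfig()` suits your runtime:

```python
from onnxtr.models import crnn_vgg16_bn  # assumed re-export of the builder shown above
from onnxtr.models.engine import EngineConfig

# Default FP32 weights with the default ONNX Runtime configuration
reco_model = crnn_vgg16_bn()

# 8-bit quantized weights: the default model_path is the remote "url", so the
# `load_in_8_bit and "http" in model_path` check in _crnn patches it to "url_8_bit"
reco_model_8bit = crnn_vgg16_bn(load_in_8_bit=True, engine_cfg=EngineConfig())
```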

onnxtr/models/recognition/models/master.py
CHANGED

@@ -11,7 +11,7 @@ from scipy.special import softmax
 
 from onnxtr.utils import VOCABS
 
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
 from ..core import RecognitionPostProcessor
 
 __all__ = ["MASTER", "master"]

@@ -24,6 +24,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 32, 128),
         "vocab": VOCABS["french"],
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/master-b1287fcd.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/master_dynamic_8_bit-d8bd8206.onnx",
     },
 }
 

@@ -35,6 +36,7 @@ class MASTER(Engine):
     ----
         model_path: path or url to onnx model file
         vocab: vocabulary, (without EOS, SOS, PAD)
+        engine_cfg: configuration for the inference engine
         cfg: dictionary containing information about the model
         **kwargs: additional arguments to be passed to `Engine`
     """

@@ -43,10 +45,11 @@ class MASTER(Engine):
         self,
         model_path: str,
         vocab: str,
+        engine_cfg: EngineConfig = EngineConfig(),
         cfg: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        super().__init__(url=model_path, **kwargs)
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
 
         self.vocab = vocab
         self.cfg = cfg

@@ -112,6 +115,8 @@ class MASTERPostProcessor(RecognitionPostProcessor):
 def _master(
     arch: str,
     model_path: str,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
     **kwargs: Any,
 ) -> MASTER:
     # Patch the config

@@ -120,11 +125,18 @@ def _master(
     _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
 
     kwargs["vocab"] = _cfg["vocab"]
+    # Patch the url
+    model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
 
-    return MASTER(model_path, cfg=_cfg, **kwargs)
+    return MASTER(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
 
 
-def master(
+def master(
+    model_path: str = default_cfgs["master"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> MASTER:
     """MASTER as described in paper: <https://arxiv.org/pdf/1910.02562.pdf>`_.
 
     >>> import numpy as np

@@ -136,10 +148,12 @@ def master(model_path: str = default_cfgs["master"]["url"], **kwargs: Any) -> MA
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keywoard arguments passed to the MASTER architecture
 
     Returns:
     -------
         text recognition architecture
     """
-    return _master("master", model_path, **kwargs)
+    return _master("master", model_path, load_in_8_bit, engine_cfg, **kwargs)

onnxtr/models/recognition/models/parseq.py
CHANGED

@@ -11,7 +11,7 @@ from scipy.special import softmax
 
 from onnxtr.utils import VOCABS
 
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
 from ..core import RecognitionPostProcessor
 
 __all__ = ["PARSeq", "parseq"]

@@ -23,6 +23,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 32, 128),
         "vocab": VOCABS["french"],
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/parseq-00b40714.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/parseq_dynamic_8_bit-5b04d9f7.onnx",
     },
 }
 

@@ -32,7 +33,9 @@ class PARSeq(Engine):
 
     Args:
     ----
+        model_path: path to onnx model file
         vocab: vocabulary used for encoding
+        engine_cfg: configuration for the inference engine
         cfg: dictionary containing information about the model
         **kwargs: additional arguments to be passed to `Engine`
     """

@@ -41,10 +44,11 @@ class PARSeq(Engine):
         self,
         model_path: str,
         vocab: str,
+        engine_cfg: EngineConfig = EngineConfig(),
         cfg: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        super().__init__(url=model_path, **kwargs)
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
         self.vocab = vocab
         self.cfg = cfg
         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)

@@ -99,6 +103,8 @@ class PARSeqPostProcessor(RecognitionPostProcessor):
 def _parseq(
     arch: str,
     model_path: str,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
     **kwargs: Any,
 ) -> PARSeq:
     # Patch the config

@@ -107,12 +113,19 @@ def _parseq(
     _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
 
     kwargs["vocab"] = _cfg["vocab"]
+    # Patch the url
+    model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
 
     # Build the model
-    return PARSeq(model_path, cfg=_cfg, **kwargs)
+    return PARSeq(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
 
 
-def parseq(
+def parseq(
+    model_path: str = default_cfgs["parseq"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> PARSeq:
     """PARSeq architecture from
     `"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
 

@@ -125,10 +138,12 @@ def parseq(model_path: str = default_cfgs["parseq"]["url"], **kwargs: Any) -> PA
     Args:
    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the PARSeq architecture
 
     Returns:
     -------
         text recognition architecture
     """
-    return _parseq("parseq", model_path, **kwargs)
+    return _parseq("parseq", model_path, load_in_8_bit, engine_cfg, **kwargs)

onnxtr/models/recognition/models/sar.py
CHANGED

@@ -11,7 +11,7 @@ from scipy.special import softmax
 
 from onnxtr.utils import VOCABS
 
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
 from ..core import RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]

@@ -23,6 +23,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 32, 128),
         "vocab": VOCABS["french"],
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/sar_resnet31-395f8005.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/sar_resnet31_static_8_bit-c07316bc.onnx",
     },
 }
 

@@ -34,6 +35,7 @@ class SAR(Engine):
     ----
         model_path: path to onnx model file
         vocab: vocabulary used for encoding
+        engine_cfg: configuration for the inference engine
         cfg: dictionary containing information about the model
         **kwargs: additional arguments to be passed to `Engine`
     """

@@ -42,10 +44,11 @@ class SAR(Engine):
         self,
         model_path: str,
         vocab: str,
+        engine_cfg: EngineConfig = EngineConfig(),
         cfg: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        super().__init__(url=model_path, **kwargs)
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
         self.vocab = vocab
         self.cfg = cfg
         self.postprocessor = SARPostProcessor(self.vocab)

@@ -99,6 +102,8 @@ class SARPostProcessor(RecognitionPostProcessor):
 def _sar(
     arch: str,
     model_path: str,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
     **kwargs: Any,
 ) -> SAR:
     # Patch the config

@@ -107,12 +112,19 @@ def _sar(
     _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
 
     kwargs["vocab"] = _cfg["vocab"]
+    # Patch the url
+    model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
 
     # Build the model
-    return SAR(model_path, cfg=_cfg, **kwargs)
+    return SAR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
 
 
-def sar_resnet31(
+def sar_resnet31(
+    model_path: str = default_cfgs["sar_resnet31"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> SAR:
     """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong
     Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 

@@ -125,10 +137,12 @@ def sar_resnet31(model_path: str = default_cfgs["sar_resnet31"]["url"], **kwargs
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the SAR architecture
 
     Returns:
     -------
         text recognition architecture
     """
-    return _sar("sar_resnet31", model_path, **kwargs)
+    return _sar("sar_resnet31", model_path, load_in_8_bit, engine_cfg, **kwargs)
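MASTER, PARSeq, and SAR follow the same recipe as the CRNN builders: `load_in_8_bit` swaps in the quantized `url_8_bit` weights when the default remote URL is used, and `engine_cfg` is passed through to `Engine`. A short sketch under the same assumption as above (builders re-exported from `onnxtr.models`):

```python
from onnxtr.models import master, parseq, sar_resnet31  # assumed re-exports
from onnxtr.models.engine import EngineConfig

cfg = EngineConfig()  # default ONNX Runtime settings

# The same two keyword arguments work across all three recognition architectures
models = {
    "master": master(load_in_8_bit=True, engine_cfg=cfg),
    "parseq": parseq(load_in_8_bit=True, engine_cfg=cfg),
    "sar_resnet31": sar_resnet31(load_in_8_bit=True, engine_cfg=cfg),
}
```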