onnxtr 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/contrib/__init__.py +1 -0
- onnxtr/contrib/artefacts.py +6 -8
- onnxtr/contrib/base.py +7 -16
- onnxtr/file_utils.py +1 -3
- onnxtr/io/elements.py +54 -60
- onnxtr/io/html.py +0 -2
- onnxtr/io/image.py +1 -4
- onnxtr/io/pdf.py +3 -5
- onnxtr/io/reader.py +4 -10
- onnxtr/models/_utils.py +10 -17
- onnxtr/models/builder.py +17 -30
- onnxtr/models/classification/models/mobilenet.py +7 -12
- onnxtr/models/classification/predictor/base.py +6 -7
- onnxtr/models/classification/zoo.py +25 -11
- onnxtr/models/detection/_utils/base.py +3 -7
- onnxtr/models/detection/core.py +2 -8
- onnxtr/models/detection/models/differentiable_binarization.py +10 -17
- onnxtr/models/detection/models/fast.py +10 -17
- onnxtr/models/detection/models/linknet.py +10 -17
- onnxtr/models/detection/postprocessor/base.py +3 -9
- onnxtr/models/detection/predictor/base.py +4 -5
- onnxtr/models/detection/zoo.py +20 -6
- onnxtr/models/engine.py +9 -9
- onnxtr/models/factory/hub.py +3 -7
- onnxtr/models/predictor/base.py +29 -30
- onnxtr/models/predictor/predictor.py +4 -5
- onnxtr/models/preprocessor/base.py +8 -12
- onnxtr/models/recognition/core.py +0 -1
- onnxtr/models/recognition/models/crnn.py +11 -23
- onnxtr/models/recognition/models/master.py +9 -15
- onnxtr/models/recognition/models/parseq.py +8 -12
- onnxtr/models/recognition/models/sar.py +8 -12
- onnxtr/models/recognition/models/vitstr.py +9 -15
- onnxtr/models/recognition/predictor/_utils.py +6 -9
- onnxtr/models/recognition/predictor/base.py +3 -3
- onnxtr/models/recognition/utils.py +2 -7
- onnxtr/models/recognition/zoo.py +19 -7
- onnxtr/models/zoo.py +7 -9
- onnxtr/transforms/base.py +17 -6
- onnxtr/utils/common_types.py +7 -8
- onnxtr/utils/data.py +7 -11
- onnxtr/utils/fonts.py +1 -6
- onnxtr/utils/geometry.py +18 -49
- onnxtr/utils/multithreading.py +3 -5
- onnxtr/utils/reconstitution.py +139 -38
- onnxtr/utils/repr.py +1 -2
- onnxtr/utils/visualization.py +12 -21
- onnxtr/utils/vocabs.py +1 -2
- onnxtr/version.py +1 -1
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/METADATA +71 -41
- onnxtr-0.6.0.dist-info/RECORD +75 -0
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/WHEEL +1 -1
- onnxtr-0.5.0.dist-info/RECORD +0 -75
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/LICENSE +0 -0
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/top_level.txt +0 -0
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/zip-safe +0 -0
onnxtr/models/predictor/base.py
CHANGED
|
@@ -3,7 +3,8 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
|
|
@@ -24,7 +25,6 @@ class _OCRPredictor:
|
|
|
24
25
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
25
26
|
|
|
26
27
|
Args:
|
|
27
|
-
----
|
|
28
28
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
29
29
|
without rotated textual elements.
|
|
30
30
|
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
|
|
@@ -39,8 +39,8 @@ class _OCRPredictor:
|
|
|
39
39
|
**kwargs: keyword args of `DocumentBuilder`
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
|
-
crop_orientation_predictor:
|
|
43
|
-
page_orientation_predictor:
|
|
42
|
+
crop_orientation_predictor: OrientationPredictor | None
|
|
43
|
+
page_orientation_predictor: OrientationPredictor | None
|
|
44
44
|
|
|
45
45
|
def __init__(
|
|
46
46
|
self,
|
|
@@ -50,7 +50,7 @@ class _OCRPredictor:
|
|
|
50
50
|
symmetric_pad: bool = True,
|
|
51
51
|
detect_orientation: bool = False,
|
|
52
52
|
load_in_8_bit: bool = False,
|
|
53
|
-
clf_engine_cfg:
|
|
53
|
+
clf_engine_cfg: EngineConfig | None = None,
|
|
54
54
|
**kwargs: Any,
|
|
55
55
|
) -> None:
|
|
56
56
|
self.assume_straight_pages = assume_straight_pages
|
|
@@ -74,12 +74,12 @@ class _OCRPredictor:
|
|
|
74
74
|
self.doc_builder = DocumentBuilder(**kwargs)
|
|
75
75
|
self.preserve_aspect_ratio = preserve_aspect_ratio
|
|
76
76
|
self.symmetric_pad = symmetric_pad
|
|
77
|
-
self.hooks:
|
|
77
|
+
self.hooks: list[Callable] = []
|
|
78
78
|
|
|
79
79
|
def _general_page_orientations(
|
|
80
80
|
self,
|
|
81
|
-
pages:
|
|
82
|
-
) ->
|
|
81
|
+
pages: list[np.ndarray],
|
|
82
|
+
) -> list[tuple[int, float]]:
|
|
83
83
|
_, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
|
|
84
84
|
# Flatten to list of tuples with (value, confidence)
|
|
85
85
|
page_orientations = [
|
|
@@ -90,8 +90,8 @@ class _OCRPredictor:
|
|
|
90
90
|
return page_orientations
|
|
91
91
|
|
|
92
92
|
def _get_orientations(
|
|
93
|
-
self, pages:
|
|
94
|
-
) ->
|
|
93
|
+
self, pages: list[np.ndarray], seg_maps: list[np.ndarray]
|
|
94
|
+
) -> tuple[list[tuple[int, float]], list[int]]:
|
|
95
95
|
general_pages_orientations = self._general_page_orientations(pages)
|
|
96
96
|
origin_page_orientations = [
|
|
97
97
|
estimate_orientation(seq_map, general_orientation)
|
|
@@ -101,11 +101,11 @@ class _OCRPredictor:
|
|
|
101
101
|
|
|
102
102
|
def _straighten_pages(
|
|
103
103
|
self,
|
|
104
|
-
pages:
|
|
105
|
-
seg_maps:
|
|
106
|
-
general_pages_orientations:
|
|
107
|
-
origin_pages_orientations:
|
|
108
|
-
) ->
|
|
104
|
+
pages: list[np.ndarray],
|
|
105
|
+
seg_maps: list[np.ndarray],
|
|
106
|
+
general_pages_orientations: list[tuple[int, float]] | None = None,
|
|
107
|
+
origin_pages_orientations: list[int] | None = None,
|
|
108
|
+
) -> list[np.ndarray]:
|
|
109
109
|
general_pages_orientations = (
|
|
110
110
|
general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
|
|
111
111
|
)
|
|
@@ -125,12 +125,12 @@ class _OCRPredictor:
|
|
|
125
125
|
|
|
126
126
|
@staticmethod
|
|
127
127
|
def _generate_crops(
|
|
128
|
-
pages:
|
|
129
|
-
loc_preds:
|
|
128
|
+
pages: list[np.ndarray],
|
|
129
|
+
loc_preds: list[np.ndarray],
|
|
130
130
|
channels_last: bool,
|
|
131
131
|
assume_straight_pages: bool = False,
|
|
132
132
|
assume_horizontal: bool = False,
|
|
133
|
-
) ->
|
|
133
|
+
) -> list[list[np.ndarray]]:
|
|
134
134
|
if assume_straight_pages:
|
|
135
135
|
crops = [
|
|
136
136
|
extract_crops(page, _boxes[:, :4], channels_last=channels_last)
|
|
@@ -145,12 +145,12 @@ class _OCRPredictor:
|
|
|
145
145
|
|
|
146
146
|
@staticmethod
|
|
147
147
|
def _prepare_crops(
|
|
148
|
-
pages:
|
|
149
|
-
loc_preds:
|
|
148
|
+
pages: list[np.ndarray],
|
|
149
|
+
loc_preds: list[np.ndarray],
|
|
150
150
|
channels_last: bool,
|
|
151
151
|
assume_straight_pages: bool = False,
|
|
152
152
|
assume_horizontal: bool = False,
|
|
153
|
-
) ->
|
|
153
|
+
) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
|
|
154
154
|
crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
|
|
155
155
|
|
|
156
156
|
# Avoid sending zero-sized crops
|
|
@@ -165,9 +165,9 @@ class _OCRPredictor:
|
|
|
165
165
|
|
|
166
166
|
def _rectify_crops(
|
|
167
167
|
self,
|
|
168
|
-
crops:
|
|
169
|
-
loc_preds:
|
|
170
|
-
) ->
|
|
168
|
+
crops: list[list[np.ndarray]],
|
|
169
|
+
loc_preds: list[np.ndarray],
|
|
170
|
+
) -> tuple[list[list[np.ndarray]], list[np.ndarray], list[tuple[int, float]]]:
|
|
171
171
|
# Work at a page level
|
|
172
172
|
orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc]
|
|
173
173
|
rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
|
|
@@ -185,10 +185,10 @@ class _OCRPredictor:
|
|
|
185
185
|
|
|
186
186
|
@staticmethod
|
|
187
187
|
def _process_predictions(
|
|
188
|
-
loc_preds:
|
|
189
|
-
word_preds:
|
|
190
|
-
crop_orientations:
|
|
191
|
-
) ->
|
|
188
|
+
loc_preds: list[np.ndarray],
|
|
189
|
+
word_preds: list[tuple[str, float]],
|
|
190
|
+
crop_orientations: list[dict[str, Any]],
|
|
191
|
+
) -> tuple[list[np.ndarray], list[list[tuple[str, float]]], list[list[dict[str, Any]]]]:
|
|
192
192
|
text_preds = []
|
|
193
193
|
crop_orientation_preds = []
|
|
194
194
|
if len(loc_preds) > 0:
|
|
@@ -205,10 +205,9 @@ class _OCRPredictor:
|
|
|
205
205
|
"""Add a hook to the predictor
|
|
206
206
|
|
|
207
207
|
Args:
|
|
208
|
-
----
|
|
209
208
|
hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
|
|
210
209
|
"""
|
|
211
210
|
self.hooks.append(hook)
|
|
212
211
|
|
|
213
|
-
def list_archs(self) ->
|
|
212
|
+
def list_archs(self) -> dict[str, list[str]]:
|
|
214
213
|
return {"detection_archs": DETECTION_ARCHS, "recognition_archs": RECOGNITION_ARCHS}
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
@@ -24,7 +24,6 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
24
24
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
25
25
|
|
|
26
26
|
Args:
|
|
27
|
-
----
|
|
28
27
|
det_predictor: detection module
|
|
29
28
|
reco_predictor: recognition module
|
|
30
29
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
@@ -52,7 +51,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
52
51
|
symmetric_pad: bool = True,
|
|
53
52
|
detect_orientation: bool = False,
|
|
54
53
|
detect_language: bool = False,
|
|
55
|
-
clf_engine_cfg:
|
|
54
|
+
clf_engine_cfg: EngineConfig | None = None,
|
|
56
55
|
**kwargs: Any,
|
|
57
56
|
) -> None:
|
|
58
57
|
self.det_predictor = det_predictor
|
|
@@ -72,7 +71,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
72
71
|
|
|
73
72
|
def __call__(
|
|
74
73
|
self,
|
|
75
|
-
pages:
|
|
74
|
+
pages: list[np.ndarray],
|
|
76
75
|
**kwargs: Any,
|
|
77
76
|
) -> Document:
|
|
78
77
|
# Dimension check
|
|
@@ -147,7 +146,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
147
146
|
boxes,
|
|
148
147
|
objectness_scores,
|
|
149
148
|
text_preds,
|
|
150
|
-
origin_page_shapes,
|
|
149
|
+
origin_page_shapes,
|
|
151
150
|
crop_orientations,
|
|
152
151
|
orientations,
|
|
153
152
|
languages_dict,
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
|
|
@@ -20,36 +20,34 @@ class PreProcessor(NestedObject):
|
|
|
20
20
|
"""Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
output_size: expected size of each page in format (H, W)
|
|
25
24
|
batch_size: the size of page batches
|
|
26
25
|
mean: mean value of the training distribution by channel
|
|
27
26
|
std: standard deviation of the training distribution by channel
|
|
27
|
+
**kwargs: additional arguments for the resizing operation
|
|
28
28
|
"""
|
|
29
29
|
|
|
30
|
-
_children_names:
|
|
30
|
+
_children_names: list[str] = ["resize", "normalize"]
|
|
31
31
|
|
|
32
32
|
def __init__(
|
|
33
33
|
self,
|
|
34
|
-
output_size:
|
|
34
|
+
output_size: tuple[int, int],
|
|
35
35
|
batch_size: int,
|
|
36
|
-
mean:
|
|
37
|
-
std:
|
|
36
|
+
mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
|
|
37
|
+
std: tuple[float, float, float] = (1.0, 1.0, 1.0),
|
|
38
38
|
**kwargs: Any,
|
|
39
39
|
) -> None:
|
|
40
40
|
self.batch_size = batch_size
|
|
41
41
|
self.resize = Resize(output_size, **kwargs)
|
|
42
42
|
self.normalize = Normalize(mean, std)
|
|
43
43
|
|
|
44
|
-
def batch_inputs(self, samples:
|
|
44
|
+
def batch_inputs(self, samples: list[np.ndarray]) -> list[np.ndarray]:
|
|
45
45
|
"""Gather samples into batches for inference purposes
|
|
46
46
|
|
|
47
47
|
Args:
|
|
48
|
-
----
|
|
49
48
|
samples: list of samples (tf.Tensor)
|
|
50
49
|
|
|
51
50
|
Returns:
|
|
52
|
-
-------
|
|
53
51
|
list of batched samples
|
|
54
52
|
"""
|
|
55
53
|
num_batches = int(math.ceil(len(samples) / self.batch_size))
|
|
@@ -76,15 +74,13 @@ class PreProcessor(NestedObject):
|
|
|
76
74
|
|
|
77
75
|
return x
|
|
78
76
|
|
|
79
|
-
def __call__(self, x:
|
|
77
|
+
def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[np.ndarray]:
|
|
80
78
|
"""Prepare document data for model forwarding
|
|
81
79
|
|
|
82
80
|
Args:
|
|
83
|
-
----
|
|
84
81
|
x: list of images (np.array) or tensors (already resized and batched)
|
|
85
82
|
|
|
86
83
|
Returns:
|
|
87
|
-
-------
|
|
88
84
|
list of page batches
|
|
89
85
|
"""
|
|
90
86
|
# Input type check
|
|
@@ -5,7 +5,7 @@
|
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
7
|
from itertools import groupby
|
|
8
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
from scipy.special import softmax
|
|
@@ -17,7 +17,7 @@ from ..core import RecognitionPostProcessor
|
|
|
17
17
|
|
|
18
18
|
__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
|
|
19
19
|
|
|
20
|
-
default_cfgs:
|
|
20
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
21
21
|
"crnn_vgg16_bn": {
|
|
22
22
|
"mean": (0.694, 0.695, 0.693),
|
|
23
23
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -49,7 +49,6 @@ class CRNNPostProcessor(RecognitionPostProcessor):
|
|
|
49
49
|
"""Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
50
50
|
|
|
51
51
|
Args:
|
|
52
|
-
----
|
|
53
52
|
vocab: string containing the ordered sequence of supported characters
|
|
54
53
|
"""
|
|
55
54
|
|
|
@@ -69,13 +68,11 @@ class CRNNPostProcessor(RecognitionPostProcessor):
|
|
|
69
68
|
<https://github.com/githubharald/CTCDecoder>`_.
|
|
70
69
|
|
|
71
70
|
Args:
|
|
72
|
-
----
|
|
73
71
|
logits: model output, shape: N x T x C
|
|
74
72
|
vocab: vocabulary to use
|
|
75
73
|
blank: index of blank label
|
|
76
74
|
|
|
77
75
|
Returns:
|
|
78
|
-
-------
|
|
79
76
|
A list of tuples: (word, confidence)
|
|
80
77
|
"""
|
|
81
78
|
# Gather the most confident characters, and assign the smallest conf among those to the sequence prob
|
|
@@ -94,11 +91,9 @@ class CRNNPostProcessor(RecognitionPostProcessor):
|
|
|
94
91
|
with label_to_idx mapping dictionnary
|
|
95
92
|
|
|
96
93
|
Args:
|
|
97
|
-
----
|
|
98
94
|
logits: raw output of the model, shape (N, C + 1, seq_len)
|
|
99
95
|
|
|
100
96
|
Returns:
|
|
101
|
-
-------
|
|
102
97
|
A tuple of 2 lists: a list of str (words) and a list of float (probs)
|
|
103
98
|
|
|
104
99
|
"""
|
|
@@ -110,7 +105,6 @@ class CRNN(Engine):
|
|
|
110
105
|
"""CRNN Onnx loader
|
|
111
106
|
|
|
112
107
|
Args:
|
|
113
|
-
----
|
|
114
108
|
model_path: path or url to onnx model file
|
|
115
109
|
vocab: vocabulary used for encoding
|
|
116
110
|
engine_cfg: configuration for the inference engine
|
|
@@ -118,14 +112,14 @@ class CRNN(Engine):
|
|
|
118
112
|
**kwargs: additional arguments to be passed to `Engine`
|
|
119
113
|
"""
|
|
120
114
|
|
|
121
|
-
_children_names:
|
|
115
|
+
_children_names: list[str] = ["postprocessor"]
|
|
122
116
|
|
|
123
117
|
def __init__(
|
|
124
118
|
self,
|
|
125
119
|
model_path: str,
|
|
126
120
|
vocab: str,
|
|
127
|
-
engine_cfg:
|
|
128
|
-
cfg:
|
|
121
|
+
engine_cfg: EngineConfig | None = None,
|
|
122
|
+
cfg: dict[str, Any] | None = None,
|
|
129
123
|
**kwargs: Any,
|
|
130
124
|
) -> None:
|
|
131
125
|
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
@@ -139,10 +133,10 @@ class CRNN(Engine):
|
|
|
139
133
|
self,
|
|
140
134
|
x: np.ndarray,
|
|
141
135
|
return_model_output: bool = False,
|
|
142
|
-
) ->
|
|
136
|
+
) -> dict[str, Any]:
|
|
143
137
|
logits = self.run(x)
|
|
144
138
|
|
|
145
|
-
out:
|
|
139
|
+
out: dict[str, Any] = {}
|
|
146
140
|
if return_model_output:
|
|
147
141
|
out["out_map"] = logits
|
|
148
142
|
|
|
@@ -156,7 +150,7 @@ def _crnn(
|
|
|
156
150
|
arch: str,
|
|
157
151
|
model_path: str,
|
|
158
152
|
load_in_8_bit: bool = False,
|
|
159
|
-
engine_cfg:
|
|
153
|
+
engine_cfg: EngineConfig | None = None,
|
|
160
154
|
**kwargs: Any,
|
|
161
155
|
) -> CRNN:
|
|
162
156
|
kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
|
|
@@ -174,7 +168,7 @@ def _crnn(
|
|
|
174
168
|
def crnn_vgg16_bn(
|
|
175
169
|
model_path: str = default_cfgs["crnn_vgg16_bn"]["url"],
|
|
176
170
|
load_in_8_bit: bool = False,
|
|
177
|
-
engine_cfg:
|
|
171
|
+
engine_cfg: EngineConfig | None = None,
|
|
178
172
|
**kwargs: Any,
|
|
179
173
|
) -> CRNN:
|
|
180
174
|
"""CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
|
|
@@ -187,14 +181,12 @@ def crnn_vgg16_bn(
|
|
|
187
181
|
>>> out = model(input_tensor)
|
|
188
182
|
|
|
189
183
|
Args:
|
|
190
|
-
----
|
|
191
184
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
192
185
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
193
186
|
engine_cfg: configuration for the inference engine
|
|
194
187
|
**kwargs: keyword arguments of the CRNN architecture
|
|
195
188
|
|
|
196
189
|
Returns:
|
|
197
|
-
-------
|
|
198
190
|
text recognition architecture
|
|
199
191
|
"""
|
|
200
192
|
return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -203,7 +195,7 @@ def crnn_vgg16_bn(
|
|
|
203
195
|
def crnn_mobilenet_v3_small(
|
|
204
196
|
model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"],
|
|
205
197
|
load_in_8_bit: bool = False,
|
|
206
|
-
engine_cfg:
|
|
198
|
+
engine_cfg: EngineConfig | None = None,
|
|
207
199
|
**kwargs: Any,
|
|
208
200
|
) -> CRNN:
|
|
209
201
|
"""CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
|
|
@@ -216,14 +208,12 @@ def crnn_mobilenet_v3_small(
|
|
|
216
208
|
>>> out = model(input_tensor)
|
|
217
209
|
|
|
218
210
|
Args:
|
|
219
|
-
----
|
|
220
211
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
221
212
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
222
213
|
engine_cfg: configuration for the inference engine
|
|
223
214
|
**kwargs: keyword arguments of the CRNN architecture
|
|
224
215
|
|
|
225
216
|
Returns:
|
|
226
|
-
-------
|
|
227
217
|
text recognition architecture
|
|
228
218
|
"""
|
|
229
219
|
return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -232,7 +222,7 @@ def crnn_mobilenet_v3_small(
|
|
|
232
222
|
def crnn_mobilenet_v3_large(
|
|
233
223
|
model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"],
|
|
234
224
|
load_in_8_bit: bool = False,
|
|
235
|
-
engine_cfg:
|
|
225
|
+
engine_cfg: EngineConfig | None = None,
|
|
236
226
|
**kwargs: Any,
|
|
237
227
|
) -> CRNN:
|
|
238
228
|
"""CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
|
|
@@ -245,14 +235,12 @@ def crnn_mobilenet_v3_large(
|
|
|
245
235
|
>>> out = model(input_tensor)
|
|
246
236
|
|
|
247
237
|
Args:
|
|
248
|
-
----
|
|
249
238
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
250
239
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
251
240
|
engine_cfg: configuration for the inference engine
|
|
252
241
|
**kwargs: keyword arguments of the CRNN architecture
|
|
253
242
|
|
|
254
243
|
Returns:
|
|
255
|
-
-------
|
|
256
244
|
text recognition architecture
|
|
257
245
|
"""
|
|
258
246
|
return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from scipy.special import softmax
|
|
@@ -17,7 +17,7 @@ from ..core import RecognitionPostProcessor
|
|
|
17
17
|
__all__ = ["MASTER", "master"]
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
default_cfgs:
|
|
20
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
21
21
|
"master": {
|
|
22
22
|
"mean": (0.694, 0.695, 0.693),
|
|
23
23
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -33,7 +33,6 @@ class MASTER(Engine):
|
|
|
33
33
|
"""MASTER Onnx loader
|
|
34
34
|
|
|
35
35
|
Args:
|
|
36
|
-
----
|
|
37
36
|
model_path: path or url to onnx model file
|
|
38
37
|
vocab: vocabulary, (without EOS, SOS, PAD)
|
|
39
38
|
engine_cfg: configuration for the inference engine
|
|
@@ -45,8 +44,8 @@ class MASTER(Engine):
|
|
|
45
44
|
self,
|
|
46
45
|
model_path: str,
|
|
47
46
|
vocab: str,
|
|
48
|
-
engine_cfg:
|
|
49
|
-
cfg:
|
|
47
|
+
engine_cfg: EngineConfig | None = None,
|
|
48
|
+
cfg: dict[str, Any] | None = None,
|
|
50
49
|
**kwargs: Any,
|
|
51
50
|
) -> None:
|
|
52
51
|
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
@@ -60,20 +59,18 @@ class MASTER(Engine):
|
|
|
60
59
|
self,
|
|
61
60
|
x: np.ndarray,
|
|
62
61
|
return_model_output: bool = False,
|
|
63
|
-
) ->
|
|
62
|
+
) -> dict[str, Any]:
|
|
64
63
|
"""Call function
|
|
65
64
|
|
|
66
65
|
Args:
|
|
67
|
-
----
|
|
68
66
|
x: images
|
|
69
67
|
return_model_output: if True, return logits
|
|
70
68
|
|
|
71
69
|
Returns:
|
|
72
|
-
-------
|
|
73
70
|
A dictionnary containing eventually logits and predictions.
|
|
74
71
|
"""
|
|
75
72
|
logits = self.run(x)
|
|
76
|
-
out:
|
|
73
|
+
out: dict[str, Any] = {}
|
|
77
74
|
|
|
78
75
|
if return_model_output:
|
|
79
76
|
out["out_map"] = logits
|
|
@@ -87,7 +84,6 @@ class MASTERPostProcessor(RecognitionPostProcessor):
|
|
|
87
84
|
"""Post-processor for the MASTER model
|
|
88
85
|
|
|
89
86
|
Args:
|
|
90
|
-
----
|
|
91
87
|
vocab: string containing the ordered sequence of supported characters
|
|
92
88
|
"""
|
|
93
89
|
|
|
@@ -98,7 +94,7 @@ class MASTERPostProcessor(RecognitionPostProcessor):
|
|
|
98
94
|
super().__init__(vocab)
|
|
99
95
|
self._embedding = list(vocab) + ["<eos>"] + ["<sos>"] + ["<pad>"]
|
|
100
96
|
|
|
101
|
-
def __call__(self, logits: np.ndarray) ->
|
|
97
|
+
def __call__(self, logits: np.ndarray) -> list[tuple[str, float]]:
|
|
102
98
|
# compute pred with argmax for attention models
|
|
103
99
|
out_idxs = np.argmax(logits, axis=-1)
|
|
104
100
|
# N x L
|
|
@@ -117,7 +113,7 @@ def _master(
|
|
|
117
113
|
arch: str,
|
|
118
114
|
model_path: str,
|
|
119
115
|
load_in_8_bit: bool = False,
|
|
120
|
-
engine_cfg:
|
|
116
|
+
engine_cfg: EngineConfig | None = None,
|
|
121
117
|
**kwargs: Any,
|
|
122
118
|
) -> MASTER:
|
|
123
119
|
# Patch the config
|
|
@@ -135,7 +131,7 @@ def _master(
|
|
|
135
131
|
def master(
|
|
136
132
|
model_path: str = default_cfgs["master"]["url"],
|
|
137
133
|
load_in_8_bit: bool = False,
|
|
138
|
-
engine_cfg:
|
|
134
|
+
engine_cfg: EngineConfig | None = None,
|
|
139
135
|
**kwargs: Any,
|
|
140
136
|
) -> MASTER:
|
|
141
137
|
"""MASTER as described in paper: <https://arxiv.org/pdf/1910.02562.pdf>`_.
|
|
@@ -147,14 +143,12 @@ def master(
|
|
|
147
143
|
>>> out = model(input_tensor)
|
|
148
144
|
|
|
149
145
|
Args:
|
|
150
|
-
----
|
|
151
146
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
152
147
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
153
148
|
engine_cfg: configuration for the inference engine
|
|
154
149
|
**kwargs: keywoard arguments passed to the MASTER architecture
|
|
155
150
|
|
|
156
151
|
Returns:
|
|
157
|
-
-------
|
|
158
152
|
text recognition architecture
|
|
159
153
|
"""
|
|
160
154
|
return _master("master", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from scipy.special import softmax
|
|
@@ -16,7 +16,7 @@ from ..core import RecognitionPostProcessor
|
|
|
16
16
|
|
|
17
17
|
__all__ = ["PARSeq", "parseq"]
|
|
18
18
|
|
|
19
|
-
default_cfgs:
|
|
19
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
20
20
|
"parseq": {
|
|
21
21
|
"mean": (0.694, 0.695, 0.693),
|
|
22
22
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -32,7 +32,6 @@ class PARSeq(Engine):
|
|
|
32
32
|
"""PARSeq Onnx loader
|
|
33
33
|
|
|
34
34
|
Args:
|
|
35
|
-
----
|
|
36
35
|
model_path: path to onnx model file
|
|
37
36
|
vocab: vocabulary used for encoding
|
|
38
37
|
engine_cfg: configuration for the inference engine
|
|
@@ -44,8 +43,8 @@ class PARSeq(Engine):
|
|
|
44
43
|
self,
|
|
45
44
|
model_path: str,
|
|
46
45
|
vocab: str,
|
|
47
|
-
engine_cfg:
|
|
48
|
-
cfg:
|
|
46
|
+
engine_cfg: EngineConfig | None = None,
|
|
47
|
+
cfg: dict[str, Any] | None = None,
|
|
49
48
|
**kwargs: Any,
|
|
50
49
|
) -> None:
|
|
51
50
|
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
@@ -59,9 +58,9 @@ class PARSeq(Engine):
|
|
|
59
58
|
self,
|
|
60
59
|
x: np.ndarray,
|
|
61
60
|
return_model_output: bool = False,
|
|
62
|
-
) ->
|
|
61
|
+
) -> dict[str, Any]:
|
|
63
62
|
logits = self.run(x)
|
|
64
|
-
out:
|
|
63
|
+
out: dict[str, Any] = {}
|
|
65
64
|
|
|
66
65
|
if return_model_output:
|
|
67
66
|
out["out_map"] = logits
|
|
@@ -74,7 +73,6 @@ class PARSeqPostProcessor(RecognitionPostProcessor):
|
|
|
74
73
|
"""Post processor for PARSeq architecture
|
|
75
74
|
|
|
76
75
|
Args:
|
|
77
|
-
----
|
|
78
76
|
vocab: string containing the ordered sequence of supported characters
|
|
79
77
|
"""
|
|
80
78
|
|
|
@@ -106,7 +104,7 @@ def _parseq(
|
|
|
106
104
|
arch: str,
|
|
107
105
|
model_path: str,
|
|
108
106
|
load_in_8_bit: bool = False,
|
|
109
|
-
engine_cfg:
|
|
107
|
+
engine_cfg: EngineConfig | None = None,
|
|
110
108
|
**kwargs: Any,
|
|
111
109
|
) -> PARSeq:
|
|
112
110
|
# Patch the config
|
|
@@ -125,7 +123,7 @@ def _parseq(
|
|
|
125
123
|
def parseq(
|
|
126
124
|
model_path: str = default_cfgs["parseq"]["url"],
|
|
127
125
|
load_in_8_bit: bool = False,
|
|
128
|
-
engine_cfg:
|
|
126
|
+
engine_cfg: EngineConfig | None = None,
|
|
129
127
|
**kwargs: Any,
|
|
130
128
|
) -> PARSeq:
|
|
131
129
|
"""PARSeq architecture from
|
|
@@ -138,14 +136,12 @@ def parseq(
|
|
|
138
136
|
>>> out = model(input_tensor)
|
|
139
137
|
|
|
140
138
|
Args:
|
|
141
|
-
----
|
|
142
139
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
143
140
|
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
144
141
|
engine_cfg: configuration for the inference engine
|
|
145
142
|
**kwargs: keyword arguments of the PARSeq architecture
|
|
146
143
|
|
|
147
144
|
Returns:
|
|
148
|
-
-------
|
|
149
145
|
text recognition architecture
|
|
150
146
|
"""
|
|
151
147
|
return _parseq("parseq", model_path, load_in_8_bit, engine_cfg, **kwargs)
|