onnxtr-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/__init__.py +2 -0
- onnxtr/contrib/__init__.py +0 -0
- onnxtr/contrib/artefacts.py +131 -0
- onnxtr/contrib/base.py +105 -0
- onnxtr/file_utils.py +33 -0
- onnxtr/io/__init__.py +5 -0
- onnxtr/io/elements.py +455 -0
- onnxtr/io/html.py +28 -0
- onnxtr/io/image.py +56 -0
- onnxtr/io/pdf.py +42 -0
- onnxtr/io/reader.py +85 -0
- onnxtr/models/__init__.py +4 -0
- onnxtr/models/_utils.py +141 -0
- onnxtr/models/builder.py +355 -0
- onnxtr/models/classification/__init__.py +2 -0
- onnxtr/models/classification/models/__init__.py +1 -0
- onnxtr/models/classification/models/mobilenet.py +120 -0
- onnxtr/models/classification/predictor/__init__.py +1 -0
- onnxtr/models/classification/predictor/base.py +57 -0
- onnxtr/models/classification/zoo.py +76 -0
- onnxtr/models/detection/__init__.py +2 -0
- onnxtr/models/detection/core.py +101 -0
- onnxtr/models/detection/models/__init__.py +3 -0
- onnxtr/models/detection/models/differentiable_binarization.py +159 -0
- onnxtr/models/detection/models/fast.py +160 -0
- onnxtr/models/detection/models/linknet.py +160 -0
- onnxtr/models/detection/postprocessor/__init__.py +0 -0
- onnxtr/models/detection/postprocessor/base.py +144 -0
- onnxtr/models/detection/predictor/__init__.py +1 -0
- onnxtr/models/detection/predictor/base.py +54 -0
- onnxtr/models/detection/zoo.py +73 -0
- onnxtr/models/engine.py +50 -0
- onnxtr/models/predictor/__init__.py +1 -0
- onnxtr/models/predictor/base.py +175 -0
- onnxtr/models/predictor/predictor.py +145 -0
- onnxtr/models/preprocessor/__init__.py +1 -0
- onnxtr/models/preprocessor/base.py +118 -0
- onnxtr/models/recognition/__init__.py +2 -0
- onnxtr/models/recognition/core.py +28 -0
- onnxtr/models/recognition/models/__init__.py +5 -0
- onnxtr/models/recognition/models/crnn.py +226 -0
- onnxtr/models/recognition/models/master.py +145 -0
- onnxtr/models/recognition/models/parseq.py +134 -0
- onnxtr/models/recognition/models/sar.py +134 -0
- onnxtr/models/recognition/models/vitstr.py +166 -0
- onnxtr/models/recognition/predictor/__init__.py +1 -0
- onnxtr/models/recognition/predictor/_utils.py +86 -0
- onnxtr/models/recognition/predictor/base.py +79 -0
- onnxtr/models/recognition/utils.py +89 -0
- onnxtr/models/recognition/zoo.py +69 -0
- onnxtr/models/zoo.py +114 -0
- onnxtr/transforms/__init__.py +1 -0
- onnxtr/transforms/base.py +112 -0
- onnxtr/utils/__init__.py +4 -0
- onnxtr/utils/common_types.py +18 -0
- onnxtr/utils/data.py +126 -0
- onnxtr/utils/fonts.py +41 -0
- onnxtr/utils/geometry.py +498 -0
- onnxtr/utils/multithreading.py +50 -0
- onnxtr/utils/reconstitution.py +70 -0
- onnxtr/utils/repr.py +64 -0
- onnxtr/utils/visualization.py +291 -0
- onnxtr/utils/vocabs.py +71 -0
- onnxtr/version.py +1 -0
- onnxtr-0.1.0.dist-info/LICENSE +201 -0
- onnxtr-0.1.0.dist-info/METADATA +481 -0
- onnxtr-0.1.0.dist-info/RECORD +70 -0
- onnxtr-0.1.0.dist-info/WHEEL +5 -0
- onnxtr-0.1.0.dist-info/top_level.txt +2 -0
- onnxtr-0.1.0.dist-info/zip-safe +1 -0
onnxtr/models/preprocessor/base.py
@@ -0,0 +1,118 @@
+# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+
+from onnxtr.transforms import Normalize, Resize
+from onnxtr.utils.geometry import shape_translate
+from onnxtr.utils.multithreading import multithread_exec
+from onnxtr.utils.repr import NestedObject
+
+__all__ = ["PreProcessor"]
+
+
+class PreProcessor(NestedObject):
+    """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
+
+    Args:
+    ----
+        output_size: expected size of each page in format (H, W)
+        batch_size: the size of page batches
+        mean: mean value of the training distribution by channel
+        std: standard deviation of the training distribution by channel
+    """
+
+    _children_names: List[str] = ["resize", "normalize"]
+
+    def __init__(
+        self,
+        output_size: Tuple[int, int],
+        batch_size: int,
+        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+        **kwargs: Any,
+    ) -> None:
+        self.batch_size = batch_size
+        self.resize = Resize(output_size, **kwargs)
+        self.normalize = Normalize(mean, std)
+
+    def batch_inputs(self, samples: List[np.ndarray]) -> List[np.ndarray]:
+        """Gather samples into batches for inference purposes
+
+        Args:
+        ----
+            samples: list of samples (np.ndarray)
+
+        Returns:
+        -------
+            list of batched samples
+        """
+        num_batches = int(math.ceil(len(samples) / self.batch_size))
+        batches = [
+            np.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0)
+            for idx in range(num_batches)
+        ]
+
+        return batches
+
+    def sample_transforms(self, x: np.ndarray) -> np.ndarray:
+        if x.ndim != 3:
+            raise AssertionError("expected list of 3D Tensors")
+        if isinstance(x, np.ndarray):
+            if x.dtype not in (np.uint8, np.float32):
+                raise TypeError("unsupported data type for numpy.ndarray")
+        x = shape_translate(x, "HWC")
+        # Data type & 255 division
+        if x.dtype == np.uint8:
+            x = x.astype(np.float32) / 255.0
+        # Resizing
+        x = self.resize(x)
+
+        return x
+
+    def __call__(self, x: Union[np.ndarray, List[np.ndarray]]) -> List[np.ndarray]:
+        """Prepare document data for model forwarding
+
+        Args:
+        ----
+            x: list of images (np.ndarray) or tensors (already resized and batched)
+
+        Returns:
+        -------
+            list of page batches
+        """
+        # Input type check
+        if isinstance(x, np.ndarray):
+            if x.ndim != 4:
+                raise AssertionError("expected 4D Tensor")
+            if x.dtype not in (np.uint8, np.float32):
+                raise TypeError("unsupported data type for numpy.ndarray")
+            x = shape_translate(x, "BHWC")
+
+            # Data type & 255 division
+            if x.dtype == np.uint8:
+                x = x.astype(np.float32) / 255.0
+            # Resizing
+            if (x.shape[1], x.shape[2]) != self.resize.output_size:
+                x = np.array([self.resize(sample) for sample in x])
+
+            batches = [x]
+
+        elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
+            # Sample transform (to tensor, resize)
+            samples = list(multithread_exec(self.sample_transforms, x))
+            # Batching
+            batches = self.batch_inputs(samples)
+        else:
+            raise TypeError(f"invalid input type: {type(x)}")
+
+        # Batch transforms (normalize)
+        batches = list(multithread_exec(self.normalize, batches))
+
+        return batches
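To make the batching arithmetic in `batch_inputs` concrete, here is a minimal, self-contained sketch of the same ceil-divide-then-stack logic (plain numpy, no onnxtr imports; the page shapes and batch size are made up for illustration):

import math

import numpy as np

# Seven fake pages, already resized to a common (H, W, C) shape
pages = [np.zeros((32, 128, 3), dtype=np.float32) for _ in range(7)]
batch_size = 2

# Same logic as PreProcessor.batch_inputs: ceil-divide, then stack each slice
num_batches = int(math.ceil(len(pages) / batch_size))
batches = [
    np.stack(pages[idx * batch_size : min((idx + 1) * batch_size, len(pages))], axis=0)
    for idx in range(num_batches)
]

print([b.shape for b in batches])
# [(2, 32, 128, 3), (2, 32, 128, 3), (2, 32, 128, 3), (1, 32, 128, 3)]

The last batch simply carries the remainder, so callers never need the page count to be a multiple of the batch size.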
onnxtr/models/recognition/core.py
@@ -0,0 +1,28 @@
+# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+from onnxtr.utils.repr import NestedObject
+
+__all__ = ["RecognitionPostProcessor"]
+
+
+class RecognitionPostProcessor(NestedObject):
+    """Abstract class to postprocess the raw output of the model
+
+    Args:
+    ----
+        vocab: string containing the ordered sequence of supported characters
+    """
+
+    def __init__(
+        self,
+        vocab: str,
+    ) -> None:
+        self.vocab = vocab
+        self._embedding = list(self.vocab) + ["<eos>"]
+
+    def extra_repr(self) -> str:
+        return f"vocab_size={len(self.vocab)}"
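A small hypothetical sketch of how the `_embedding` table built here is consumed by the concrete post-processors below (the index sequence is invented):

vocab = "abc"
embedding = list(vocab) + ["<eos>"]  # index 3 acts as the end-of-sequence marker

# A made-up index sequence, as a recognition model might emit it
encoded_seq = [1, 0, 2, 3, 0]

# Map indices to symbols, then keep everything before the first <eos>
decoded = "".join(embedding[idx] for idx in encoded_seq).split("<eos>")[0]
print(decoded)  # bac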
onnxtr/models/recognition/models/crnn.py
@@ -0,0 +1,226 @@
+# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from itertools import groupby
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+from scipy.special import softmax
+
+from onnxtr.utils import VOCABS
+
+from ...engine import Engine
+from ..core import RecognitionPostProcessor
+
+__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "crnn_vgg16_bn": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["legacy_french"],
+        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_vgg16_bn-662979cc.onnx",
+    },
+    "crnn_mobilenet_v3_small": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["french"],
+        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_mobilenet_v3_small-bded4d49.onnx",
+    },
+    "crnn_mobilenet_v3_large": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["french"],
+        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_mobilenet_v3_large-d42e8185.onnx",
+    },
+}
+
+
+class CRNNPostProcessor(RecognitionPostProcessor):
+    """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
+
+    Args:
+    ----
+        vocab: string containing the ordered sequence of supported characters
+    """
+
+    def __init__(self, vocab):
+        self.vocab = vocab
+
+    def decode_sequence(self, sequence, vocab):
+        return "".join([vocab[int(char)] for char in sequence])
+
+    def ctc_best_path(
+        self,
+        logits,
+        vocab,
+        blank=0,
+    ):
+        """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
+        `<https://github.com/githubharald/CTCDecoder>`_.
+
+        Args:
+        ----
+            logits: model output, shape: N x T x C
+            vocab: vocabulary to use
+            blank: index of blank label
+
+        Returns:
+        -------
+            A list of tuples: (word, confidence)
+        """
+        # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
+        probs = softmax(logits, axis=-1).max(axis=-1).min(axis=1)
+
+        # collapse best path (using itertools.groupby), map to chars, join char list to string
+        words = [
+            self.decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
+            for seq in np.argmax(logits, axis=-1)
+        ]
+
+        return list(zip(words, probs.astype(float).tolist()))
+
+    def __call__(self, logits):
+        """Performs decoding of raw output with CTC and decoding of CTC predictions
+        with label_to_idx mapping dictionary
+
+        Args:
+        ----
+            logits: raw output of the model, shape (N, C + 1, seq_len)
+
+        Returns:
+        -------
+            A list of tuples: (word, confidence)
+
+        """
+        # Decode CTC
+        return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
+
+
+class CRNN(Engine):
+    """CRNN Onnx loader
+
+    Args:
+    ----
+        model_path: path or url to onnx model file
+        vocab: vocabulary used for encoding
+        cfg: configuration dictionary
+        **kwargs: additional arguments to be passed to `Engine`
+    """
+
+    _children_names: List[str] = ["postprocessor"]
+
+    def __init__(
+        self,
+        model_path: str,
+        vocab: str,
+        cfg: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(url=model_path, **kwargs)
+        self.vocab = vocab
+        self.cfg = cfg
+        self.postprocessor = CRNNPostProcessor(self.vocab)
+
+    def __call__(
+        self,
+        x: np.ndarray,
+        return_model_output: bool = False,
+    ) -> Dict[str, Any]:
+        logits = self.run(x)
+
+        out: Dict[str, Any] = {}
+        if return_model_output:
+            out["out_map"] = logits
+
+        # Post-process
+        out["preds"] = self.postprocessor(logits)
+
+        return out
+
+
+def _crnn(
+    arch: str,
+    model_path: str,
+    **kwargs: Any,
+) -> CRNN:
+    kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
+
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["vocab"] = kwargs["vocab"]
+    _cfg["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+
+    # Build the model
+    return CRNN(model_path, cfg=_cfg, **kwargs)
+
+
+def crnn_vgg16_bn(model_path: str = default_cfgs["crnn_vgg16_bn"]["url"], **kwargs: Any) -> CRNN:
+    """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+    >>> import numpy as np
+    >>> from onnxtr.models import crnn_vgg16_bn
+    >>> model = crnn_vgg16_bn()
+    >>> input_tensor = np.random.rand(1, 3, 32, 128)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        model_path: path to onnx model file, defaults to url in default_cfgs
+        **kwargs: keyword arguments of the CRNN architecture
+
+    Returns:
+    -------
+        text recognition architecture
+    """
+    return _crnn("crnn_vgg16_bn", model_path, **kwargs)
+
+
+def crnn_mobilenet_v3_small(model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"], **kwargs: Any) -> CRNN:
+    """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+    >>> import numpy as np
+    >>> from onnxtr.models import crnn_mobilenet_v3_small
+    >>> model = crnn_mobilenet_v3_small()
+    >>> input_tensor = np.random.rand(1, 3, 32, 128)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        model_path: path to onnx model file, defaults to url in default_cfgs
+        **kwargs: keyword arguments of the CRNN architecture
+
+    Returns:
+    -------
+        text recognition architecture
+    """
+    return _crnn("crnn_mobilenet_v3_small", model_path, **kwargs)
+
+
+def crnn_mobilenet_v3_large(model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"], **kwargs: Any) -> CRNN:
+    """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+    >>> import numpy as np
+    >>> from onnxtr.models import crnn_mobilenet_v3_large
+    >>> model = crnn_mobilenet_v3_large()
+    >>> input_tensor = np.random.rand(1, 3, 32, 128)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        model_path: path to onnx model file, defaults to url in default_cfgs
+        **kwargs: keyword arguments of the CRNN architecture
+
+    Returns:
+    -------
+        text recognition architecture
+    """
+    return _crnn("crnn_mobilenet_v3_large", model_path, **kwargs)
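To see what `ctc_best_path` does step by step, here is a self-contained sketch with toy logits (the two-character vocab and the values are made up; the decoding logic mirrors the method above):

from itertools import groupby

import numpy as np
from scipy.special import softmax

vocab = "ab"
blank = len(vocab)  # index 2 is the CTC blank, as in CRNNPostProcessor

# Toy logits for one sequence of 5 time steps over 3 classes (a, b, blank)
logits = np.array([[
    [5.0, 0.0, 0.0],   # a
    [5.0, 0.0, 0.0],   # a (repeat, collapsed)
    [0.0, 0.0, 5.0],   # blank (dropped)
    [0.0, 5.0, 0.0],   # b
    [0.0, 5.0, 0.0],   # b (repeat, collapsed)
]])

# Sequence confidence: min over time of the max class probability
probs = softmax(logits, axis=-1).max(axis=-1).min(axis=1)

# Collapse repeats, drop blanks, map the rest to characters
words = [
    "".join(vocab[k] for k, _ in groupby(seq.tolist()) if k != blank)
    for seq in np.argmax(logits, axis=-1)
]
print(list(zip(words, probs.round(3).tolist())))  # [('ab', 0.987)]

Repeats merge and blanks separate identical characters, which is exactly why CTC can emit "ab" from five time steps.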
onnxtr/models/recognition/models/master.py
@@ -0,0 +1,145 @@
+# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+from scipy.special import softmax
+
+from onnxtr.utils import VOCABS
+
+from ...engine import Engine
+from ..core import RecognitionPostProcessor
+
+__all__ = ["MASTER", "master"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "master": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["french"],
+        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/master-b1287fcd.onnx",
+    },
+}
+
+
+class MASTER(Engine):
+    """MASTER Onnx loader
+
+    Args:
+    ----
+        model_path: path or url to onnx model file
+        vocab: vocabulary (without EOS, SOS, PAD)
+        cfg: dictionary containing information about the model
+        **kwargs: additional arguments to be passed to `Engine`
+    """
+
+    def __init__(
+        self,
+        model_path: str,
+        vocab: str,
+        cfg: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(url=model_path, **kwargs)
+
+        self.vocab = vocab
+        self.cfg = cfg
+        self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
+
+    def __call__(
+        self,
+        x: np.ndarray,
+        return_model_output: bool = False,
+    ) -> Dict[str, Any]:
+        """Call function
+
+        Args:
+        ----
+            x: images
+            return_model_output: if True, return logits
+
+        Returns:
+        -------
+            A dictionary containing the predictions and, optionally, the raw logits.
+        """
+        logits = self.run(x)
+        out: Dict[str, Any] = {}
+
+        if return_model_output:
+            out["out_map"] = logits
+
+        out["preds"] = self.postprocessor(logits)
+
+        return out
+
+
+class MASTERPostProcessor(RecognitionPostProcessor):
+    """Post-processor for the MASTER model
+
+    Args:
+    ----
+        vocab: string containing the ordered sequence of supported characters
+    """
+
+    def __init__(
+        self,
+        vocab: str,
+    ) -> None:
+        super().__init__(vocab)
+        self._embedding = list(vocab) + ["<eos>"] + ["<sos>"] + ["<pad>"]
+
+    def __call__(self, logits: np.ndarray) -> List[Tuple[str, float]]:
+        # compute pred with argmax for attention models
+        out_idxs = np.argmax(logits, axis=-1)
+        # N x L
+        probs = np.take_along_axis(softmax(logits, axis=-1), out_idxs[..., None], axis=-1).squeeze(-1)
+        # Take the minimum confidence of the sequence
+        probs = np.min(probs, axis=1)
+
+        word_values = [
+            "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0] for encoded_seq in out_idxs
+        ]
+
+        return list(zip(word_values, np.clip(probs, 0, 1).astype(float).tolist()))
+
+
+def _master(
+    arch: str,
+    model_path: str,
+    **kwargs: Any,
+) -> MASTER:
+    # Patch the config
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+
+    kwargs["vocab"] = _cfg["vocab"]
+
+    return MASTER(model_path, cfg=_cfg, **kwargs)
+
+
+def master(model_path: str = default_cfgs["master"]["url"], **kwargs: Any) -> MASTER:
+    """MASTER as described in `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition"
+    <https://arxiv.org/pdf/1910.02562.pdf>`_.
+
+    >>> import numpy as np
+    >>> from onnxtr.models import master
+    >>> model = master()
+    >>> input_tensor = np.random.rand(1, 3, 32, 128)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        model_path: path to onnx model file, defaults to url in default_cfgs
+        **kwargs: keyword arguments passed to the MASTER architecture
+
+    Returns:
+    -------
+        text recognition architecture
+    """
+    return _master("master", model_path, **kwargs)
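The attention-style decoding in `MASTERPostProcessor` can likewise be replayed in isolation. A minimal sketch with an invented three-character vocab and toy logits, showing the argmax, `<eos>`-split, and min-confidence steps:

import numpy as np
from scipy.special import softmax

vocab = "hi!"
embedding = list(vocab) + ["<eos>", "<sos>", "<pad>"]

# Toy logits: one sequence of 4 steps over 6 classes (h, i, !, eos, sos, pad)
logits = np.zeros((1, 4, 6), dtype=np.float32)
for t, cls in enumerate([0, 1, 3, 5]):  # "h", "i", <eos>, <pad>
    logits[0, t, cls] = 6.0

out_idxs = np.argmax(logits, axis=-1)
# Per-step probability of the chosen class, then min over the sequence
probs = np.take_along_axis(softmax(logits, axis=-1), out_idxs[..., None], axis=-1).squeeze(-1)
probs = probs.min(axis=1)

# Map indices to symbols and cut at the first <eos>
words = ["".join(embedding[idx] for idx in seq).split("<eos>")[0] for seq in out_idxs]
print(list(zip(words, probs.round(3).tolist())))  # [('hi', 0.988)]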
onnxtr/models/recognition/models/parseq.py
@@ -0,0 +1,134 @@
+# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Dict, Optional
+
+import numpy as np
+from scipy.special import softmax
+
+from onnxtr.utils import VOCABS
+
+from ...engine import Engine
+from ..core import RecognitionPostProcessor
+
+__all__ = ["PARSeq", "parseq"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "parseq": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["french"],
+        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/parseq-00b40714.onnx",
+    },
+}
+
+
+class PARSeq(Engine):
+    """PARSeq Onnx loader
+
+    Args:
+    ----
+        model_path: path or url to onnx model file
+        vocab: vocabulary used for encoding
+        cfg: dictionary containing information about the model
+        **kwargs: additional arguments to be passed to `Engine`
+    """
+
+    def __init__(
+        self,
+        model_path: str,
+        vocab: str,
+        cfg: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(url=model_path, **kwargs)
+        self.vocab = vocab
+        self.cfg = cfg
+        self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
+
+    def __call__(
+        self,
+        x: np.ndarray,
+        return_model_output: bool = False,
+    ) -> Dict[str, Any]:
+        logits = self.run(x)
+        out: Dict[str, Any] = {}
+
+        if return_model_output:
+            out["out_map"] = logits
+
+        out["preds"] = self.postprocessor(logits)
+        return out
+
+
+class PARSeqPostProcessor(RecognitionPostProcessor):
+    """Post processor for PARSeq architecture
+
+    Args:
+    ----
+        vocab: string containing the ordered sequence of supported characters
+    """
+
+    def __init__(
+        self,
+        vocab: str,
+    ) -> None:
+        super().__init__(vocab)
+        self._embedding = list(vocab) + ["<eos>", "<sos>", "<pad>"]
+
+    def __call__(self, logits):
+        # compute pred with argmax for attention models
+        out_idxs = np.argmax(logits, axis=-1)
+        preds_prob = softmax(logits, axis=-1).max(axis=-1)
+
+        word_values = [
+            "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0] for encoded_seq in out_idxs
+        ]
+        # compute probabilities for each word up to the EOS token
+        probs = [
+            preds_prob[i, : len(word)].clip(0, 1).mean().astype(float) if word else 0.0
+            for i, word in enumerate(word_values)
+        ]
+
+        return list(zip(word_values, probs))
+
+
+def _parseq(
+    arch: str,
+    model_path: str,
+    **kwargs: Any,
+) -> PARSeq:
+    # Patch the config
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+
+    kwargs["vocab"] = _cfg["vocab"]
+
+    # Build the model
+    return PARSeq(model_path, cfg=_cfg, **kwargs)
+
+
+def parseq(model_path: str = default_cfgs["parseq"]["url"], **kwargs: Any) -> PARSeq:
+    """PARSeq architecture from
+    `"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
+
+    >>> import numpy as np
+    >>> from onnxtr.models import parseq
+    >>> model = parseq()
+    >>> input_tensor = np.random.rand(1, 3, 32, 128)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        model_path: path to onnx model file, defaults to url in default_cfgs
+        **kwargs: keyword arguments of the PARSeq architecture
+
+    Returns:
+    -------
+        text recognition architecture
+    """
+    return _parseq("parseq", model_path, **kwargs)
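A note on confidence aggregation: `PARSeqPostProcessor` averages the per-character probabilities of the characters it keeps before `<eos>`, whereas `MASTERPostProcessor` takes the minimum over the whole sequence. A toy comparison with made-up numbers:

import numpy as np

# Hypothetical per-step probabilities of the predicted characters for one word
preds_prob = np.array([[0.99, 0.95, 0.60, 0.98]])
word = "tex"  # suppose the decoded word has 3 characters before <eos>

# PARSeq-style confidence: mean over the characters actually kept
parseq_conf = preds_prob[0, : len(word)].clip(0, 1).mean()
# MASTER-style confidence: min over the whole sequence
master_conf = preds_prob[0].min()

print(round(float(parseq_conf), 3), round(float(master_conf), 3))  # 0.847 0.6

The mean is more forgiving of a single uncertain character, while the min flags a word as soon as any step is doubtful.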