onnxtr-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/__init__.py +2 -0
- onnxtr/contrib/__init__.py +0 -0
- onnxtr/contrib/artefacts.py +131 -0
- onnxtr/contrib/base.py +105 -0
- onnxtr/file_utils.py +33 -0
- onnxtr/io/__init__.py +5 -0
- onnxtr/io/elements.py +455 -0
- onnxtr/io/html.py +28 -0
- onnxtr/io/image.py +56 -0
- onnxtr/io/pdf.py +42 -0
- onnxtr/io/reader.py +85 -0
- onnxtr/models/__init__.py +4 -0
- onnxtr/models/_utils.py +141 -0
- onnxtr/models/builder.py +355 -0
- onnxtr/models/classification/__init__.py +2 -0
- onnxtr/models/classification/models/__init__.py +1 -0
- onnxtr/models/classification/models/mobilenet.py +120 -0
- onnxtr/models/classification/predictor/__init__.py +1 -0
- onnxtr/models/classification/predictor/base.py +57 -0
- onnxtr/models/classification/zoo.py +76 -0
- onnxtr/models/detection/__init__.py +2 -0
- onnxtr/models/detection/core.py +101 -0
- onnxtr/models/detection/models/__init__.py +3 -0
- onnxtr/models/detection/models/differentiable_binarization.py +159 -0
- onnxtr/models/detection/models/fast.py +160 -0
- onnxtr/models/detection/models/linknet.py +160 -0
- onnxtr/models/detection/postprocessor/__init__.py +0 -0
- onnxtr/models/detection/postprocessor/base.py +144 -0
- onnxtr/models/detection/predictor/__init__.py +1 -0
- onnxtr/models/detection/predictor/base.py +54 -0
- onnxtr/models/detection/zoo.py +73 -0
- onnxtr/models/engine.py +50 -0
- onnxtr/models/predictor/__init__.py +1 -0
- onnxtr/models/predictor/base.py +175 -0
- onnxtr/models/predictor/predictor.py +145 -0
- onnxtr/models/preprocessor/__init__.py +1 -0
- onnxtr/models/preprocessor/base.py +118 -0
- onnxtr/models/recognition/__init__.py +2 -0
- onnxtr/models/recognition/core.py +28 -0
- onnxtr/models/recognition/models/__init__.py +5 -0
- onnxtr/models/recognition/models/crnn.py +226 -0
- onnxtr/models/recognition/models/master.py +145 -0
- onnxtr/models/recognition/models/parseq.py +134 -0
- onnxtr/models/recognition/models/sar.py +134 -0
- onnxtr/models/recognition/models/vitstr.py +166 -0
- onnxtr/models/recognition/predictor/__init__.py +1 -0
- onnxtr/models/recognition/predictor/_utils.py +86 -0
- onnxtr/models/recognition/predictor/base.py +79 -0
- onnxtr/models/recognition/utils.py +89 -0
- onnxtr/models/recognition/zoo.py +69 -0
- onnxtr/models/zoo.py +114 -0
- onnxtr/transforms/__init__.py +1 -0
- onnxtr/transforms/base.py +112 -0
- onnxtr/utils/__init__.py +4 -0
- onnxtr/utils/common_types.py +18 -0
- onnxtr/utils/data.py +126 -0
- onnxtr/utils/fonts.py +41 -0
- onnxtr/utils/geometry.py +498 -0
- onnxtr/utils/multithreading.py +50 -0
- onnxtr/utils/reconstitution.py +70 -0
- onnxtr/utils/repr.py +64 -0
- onnxtr/utils/visualization.py +291 -0
- onnxtr/utils/vocabs.py +71 -0
- onnxtr/version.py +1 -0
- onnxtr-0.1.0.dist-info/LICENSE +201 -0
- onnxtr-0.1.0.dist-info/METADATA +481 -0
- onnxtr-0.1.0.dist-info/RECORD +70 -0
- onnxtr-0.1.0.dist-info/WHEEL +5 -0
- onnxtr-0.1.0.dist-info/top_level.txt +2 -0
- onnxtr-0.1.0.dist-info/zip-safe +1 -0
@@ -0,0 +1,134 @@ onnxtr/models/recognition/models/sar.py

# Copyright (C) 2021-2024, Mindee | Felix Dittrich.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any, Dict, Optional

import numpy as np
from scipy.special import softmax

from onnxtr.utils import VOCABS

from ...engine import Engine
from ..core import RecognitionPostProcessor

__all__ = ["SAR", "sar_resnet31"]

default_cfgs: Dict[str, Dict[str, Any]] = {
    "sar_resnet31": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 128),
        "vocab": VOCABS["french"],
        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/sar_resnet31-395f8005.onnx",
    },
}


class SAR(Engine):
    """SAR Onnx loader

    Args:
    ----
        model_path: path to onnx model file
        vocab: vocabulary used for encoding
        cfg: dictionary containing information about the model
        **kwargs: additional arguments to be passed to `Engine`
    """

    def __init__(
        self,
        model_path: str,
        vocab: str,
        cfg: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(url=model_path, **kwargs)
        self.vocab = vocab
        self.cfg = cfg
        self.postprocessor = SARPostProcessor(self.vocab)

    def __call__(
        self,
        x: np.ndarray,
        return_model_output: bool = False,
    ) -> Dict[str, Any]:
        logits = self.run(x)

        out: Dict[str, Any] = {}
        if return_model_output:
            out["out_map"] = logits

        out["preds"] = self.postprocessor(logits)

        return out


class SARPostProcessor(RecognitionPostProcessor):
    """Post processor for SAR architectures

    Args:
    ----
        vocab: string containing the ordered sequence of supported characters
    """

    def __init__(
        self,
        vocab: str,
    ) -> None:
        super().__init__(vocab)
        self._embedding = list(self.vocab) + ["<eos>"]

    def __call__(self, logits):
        # compute pred with argmax for attention models
        out_idxs = np.argmax(logits, axis=-1)
        # N x L
        probs = np.take_along_axis(softmax(logits, axis=-1), out_idxs[..., None], axis=-1).squeeze(-1)
        # Take the minimum confidence of the sequence
        probs = np.min(probs, axis=1)

        word_values = [
            "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0] for encoded_seq in out_idxs
        ]

        return list(zip(word_values, np.clip(probs, 0, 1).astype(float).tolist()))


def _sar(
    arch: str,
    model_path: str,
    **kwargs: Any,
) -> SAR:
    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])

    kwargs["vocab"] = _cfg["vocab"]

    # Build the model
    return SAR(model_path, cfg=_cfg, **kwargs)


def sar_resnet31(model_path: str = default_cfgs["sar_resnet31"]["url"], **kwargs: Any) -> SAR:
    """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read: A Simple and Strong
    Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

    >>> import numpy as np
    >>> from onnxtr.models import sar_resnet31
    >>> model = sar_resnet31()
    >>> input_tensor = np.random.rand(1, 3, 32, 128)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path to onnx model file, defaults to url in default_cfgs
        **kwargs: keyword arguments of the SAR architecture

    Returns:
    -------
        text recognition architecture
    """
    return _sar("sar_resnet31", model_path, **kwargs)
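To make the decoding step above concrete, here is a minimal, self-contained sketch of what SARPostProcessor does with raw logits. The two-character vocabulary and the logits array are invented for illustration (the real model uses VOCABS["french"] and ONNX output); the snippet simply replays the argmax / min-confidence logic shown in the hunk.

import numpy as np
from scipy.special import softmax

# Toy stand-ins: a 2-character vocab plus the <eos> slot that SARPostProcessor appends.
embedding = list("ab") + ["<eos>"]

# Fake logits: batch of 1, sequence length 4, 3 classes (a, b, <eos>).
logits = np.array([[[5.0, 0.0, 0.0],    # -> "a"
                    [0.0, 4.0, 0.0],    # -> "b"
                    [0.0, 0.0, 6.0],    # -> <eos>, decoding stops here
                    [1.0, 0.0, 0.0]]])  # padding after <eos>, still counted in the min below

out_idxs = np.argmax(logits, axis=-1)
probs = np.take_along_axis(softmax(logits, axis=-1), out_idxs[..., None], axis=-1).squeeze(-1)
probs = np.min(probs, axis=1)  # sequence confidence = weakest position of the full (padded) sequence

words = ["".join(embedding[idx] for idx in seq).split("<eos>")[0] for seq in out_idxs]
print(list(zip(words, np.clip(probs, 0, 1).astype(float).tolist())))  # [('ab', 0.57...)]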
@@ -0,0 +1,166 @@ onnxtr/models/recognition/models/vitstr.py

# Copyright (C) 2021-2024, Mindee | Felix Dittrich.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any, Dict, Optional

import numpy as np
from scipy.special import softmax

from onnxtr.utils import VOCABS

from ...engine import Engine
from ..core import RecognitionPostProcessor

__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

default_cfgs: Dict[str, Dict[str, Any]] = {
    "vitstr_small": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 128),
        "vocab": VOCABS["french"],
        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/vitstr_small-3ff9c500.onnx",
    },
    "vitstr_base": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 128),
        "vocab": VOCABS["french"],
        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/vitstr_base-ff62f5be.onnx",
    },
}


class ViTSTR(Engine):
    """ViTSTR Onnx loader

    Args:
    ----
        model_path: path to onnx model file
        vocab: vocabulary used for encoding
        cfg: dictionary containing information about the model
        **kwargs: additional arguments to be passed to `Engine`
    """

    def __init__(
        self,
        model_path: str,
        vocab: str,
        cfg: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(url=model_path, **kwargs)
        self.vocab = vocab
        self.cfg = cfg

        self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)

    def __call__(
        self,
        x: np.ndarray,
        return_model_output: bool = False,
    ) -> Dict[str, Any]:
        logits = self.run(x)

        out: Dict[str, Any] = {}
        if return_model_output:
            out["out_map"] = logits

        out["preds"] = self.postprocessor(logits)

        return out


class ViTSTRPostProcessor(RecognitionPostProcessor):
    """Post processor for ViTSTR architecture

    Args:
    ----
        vocab: string containing the ordered sequence of supported characters
    """

    def __init__(
        self,
        vocab: str,
    ) -> None:
        super().__init__(vocab)
        self._embedding = list(vocab) + ["<eos>", "<sos>"]

    def __call__(self, logits):
        # compute pred with argmax for attention models
        out_idxs = np.argmax(logits, axis=-1)
        preds_prob = softmax(logits, axis=-1).max(axis=-1)

        word_values = [
            "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0] for encoded_seq in out_idxs
        ]
        # compute probabilities for each word up to the EOS token
        probs = [
            preds_prob[i, : len(word)].clip(0, 1).mean().astype(float) if word else 0.0
            for i, word in enumerate(word_values)
        ]

        return list(zip(word_values, probs))


def _vitstr(
    arch: str,
    model_path: str,
    **kwargs: Any,
) -> ViTSTR:
    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])

    kwargs["vocab"] = _cfg["vocab"]

    # Build the model
    return ViTSTR(model_path, cfg=_cfg, **kwargs)


def vitstr_small(model_path: str = default_cfgs["vitstr_small"]["url"], **kwargs: Any) -> ViTSTR:
    """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
    <https://arxiv.org/pdf/2105.08582.pdf>`_.

    >>> import numpy as np
    >>> from onnxtr.models import vitstr_small
    >>> model = vitstr_small()
    >>> input_tensor = np.random.rand(1, 3, 32, 128)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path to onnx model file, defaults to url in default_cfgs
        **kwargs: keyword arguments of the ViTSTR architecture

    Returns:
    -------
        text recognition architecture
    """
    return _vitstr("vitstr_small", model_path, **kwargs)


def vitstr_base(model_path: str = default_cfgs["vitstr_base"]["url"], **kwargs: Any) -> ViTSTR:
    """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
    <https://arxiv.org/pdf/2105.08582.pdf>`_.

    >>> import numpy as np
    >>> from onnxtr.models import vitstr_base
    >>> model = vitstr_base()
    >>> input_tensor = np.random.rand(1, 3, 32, 128)
    >>> out = model(input_tensor)

    Args:
    ----
        model_path: path to onnx model file, defaults to url in default_cfgs
        **kwargs: keyword arguments of the ViTSTR architecture

    Returns:
    -------
        text recognition architecture
    """
    return _vitstr("vitstr_base", model_path, **kwargs)
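The ViTSTR post-processor differs from SAR's mainly in how it aggregates confidence: instead of taking the minimum probability over the whole padded sequence, it averages the probabilities of only the characters kept before <eos>. A small sketch with invented logits (not the package's API, just the same numpy logic replayed on toy values):

import numpy as np
from scipy.special import softmax

embedding = list("ab") + ["<eos>", "<sos>"]  # toy vocab; the real one is VOCABS["french"]

# Fake logits: batch of 1, sequence length 4, 4 classes (a, b, <eos>, <sos>).
logits = np.array([[[4.0, 0.0, 0.0, 0.0],   # -> "a"
                    [0.0, 3.0, 0.0, 0.0],   # -> "b"
                    [0.0, 0.0, 5.0, 0.0],   # -> <eos>
                    [0.0, 0.0, 5.0, 0.0]]])

out_idxs = np.argmax(logits, axis=-1)
preds_prob = softmax(logits, axis=-1).max(axis=-1)

word_values = ["".join(embedding[idx] for idx in seq).split("<eos>")[0] for seq in out_idxs]
# Mean over the kept characters only ("ab" -> first two positions), unlike SAR's global min
probs = [preds_prob[i, : len(word)].clip(0, 1).mean().astype(float) if word else 0.0
         for i, word in enumerate(word_values)]
print(list(zip(word_values, probs)))  # [('ab', 0.90...)]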
@@ -0,0 +1 @@ onnxtr/models/recognition/predictor/__init__.py

from .base import *
@@ -0,0 +1,86 @@ onnxtr/models/recognition/predictor/_utils.py

# Copyright (C) 2021-2024, Mindee | Felix Dittrich.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import List, Tuple, Union

import numpy as np

from ..utils import merge_multi_strings

__all__ = ["split_crops", "remap_preds"]


def split_crops(
    crops: List[np.ndarray],
    max_ratio: float,
    target_ratio: int,
    dilation: float,
    channels_last: bool = True,
) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
    """Chunk crops horizontally to match a given aspect ratio

    Args:
    ----
        crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
        max_ratio: the maximum aspect ratio that won't trigger the chunk
        target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
        dilation: the width dilation of final chunks (to provide some overlaps)
        channels_last: whether the numpy array has dimensions in channels last order

    Returns:
    -------
        a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
    """
    _remap_required = False
    crop_map: List[Union[int, Tuple[int, int]]] = []
    new_crops: List[np.ndarray] = []
    for crop in crops:
        h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
        aspect_ratio = w / h
        if aspect_ratio > max_ratio:
            # Determine the number of crops, reference aspect ratio = 4 = 128 / 32
            num_subcrops = int(aspect_ratio // target_ratio)
            # Find the new widths, additional dilation factor to overlap crops
            width = dilation * w / num_subcrops
            centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
            # Get the crops
            if channels_last:
                _crops = [
                    crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :]
                    for center in centers
                ]
            else:
                _crops = [
                    crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2)))]
                    for center in centers
                ]
            # Avoid sending zero-sized crops
            _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
            # Record the slice of crops
            crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
            new_crops.extend(_crops)
            # At least one crop will require merging
            _remap_required = True
        else:
            crop_map.append(len(new_crops))
            new_crops.append(crop)

    return new_crops, crop_map, _remap_required


def remap_preds(
    preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
) -> List[Tuple[str, float]]:
    remapped_out = []
    for _idx in crop_map:
        # Crop hasn't been split
        if isinstance(_idx, int):
            remapped_out.append(preds[_idx])
        else:
            # unzip
            vals, probs = zip(*preds[_idx[0] : _idx[1]])
            # Merge the string values
            remapped_out.append((merge_multi_strings(vals, dilation), min(probs)))  # type: ignore[arg-type]
    return remapped_out
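A short round-trip sketch of the two helpers above, assuming the wheel (and its rapidfuzz dependency) is installed so the imports resolve; the crops are blank synthetic images and the "recognized" strings are invented, purely to show how overlapping chunks get merged back.

import numpy as np

from onnxtr.models.recognition.predictor._utils import remap_preds, split_crops

# One very wide crop (aspect ratio 16) and one ordinary crop (aspect ratio 3)
wide = np.zeros((32, 512, 3), dtype=np.uint8)
narrow = np.zeros((32, 96, 3), dtype=np.uint8)

crops, crop_map, remapped = split_crops([wide, narrow], max_ratio=8, target_ratio=6, dilation=1.4)
print(len(crops), crop_map, remapped)  # 3 [(0, 2), 2] True -> the wide crop became 2 overlapping chunks

# Pretend the recognizer read these from the two overlapping chunks and the ordinary crop
preds = [("abcd", 0.9), ("cdef", 0.8), ("xy", 0.95)]
print(remap_preds(preds, crop_map, dilation=1.4))  # [('abcdef', 0.8), ('xy', 0.95)]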
@@ -0,0 +1,79 @@ onnxtr/models/recognition/predictor/base.py

# Copyright (C) 2021-2024, Mindee | Felix Dittrich.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List, Sequence, Tuple

import numpy as np

from onnxtr.models.preprocessor import PreProcessor
from onnxtr.utils.repr import NestedObject

from ._utils import remap_preds, split_crops

__all__ = ["RecognitionPredictor"]


class RecognitionPredictor(NestedObject):
    """Implements an object able to identify character sequences in images

    Args:
    ----
        pre_processor: transform inputs for easier batched model inference
        model: core recognition architecture
        split_wide_crops: whether to use crop splitting for high aspect ratio crops
    """

    def __init__(
        self,
        pre_processor: PreProcessor,
        model: Any,
        split_wide_crops: bool = True,
    ) -> None:
        super().__init__()
        self.pre_processor = pre_processor
        self.model = model
        self.split_wide_crops = split_wide_crops
        self.critical_ar = 8  # Critical aspect ratio
        self.dil_factor = 1.4  # Dilation factor to overlap the crops
        self.target_ar = 6  # Target aspect ratio

    def __call__(
        self,
        crops: Sequence[np.ndarray],
        **kwargs: Any,
    ) -> List[Tuple[str, float]]:
        if len(crops) == 0:
            return []
        # Dimension check
        if any(crop.ndim != 3 for crop in crops):
            raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")

        # Split crops that are too wide
        remapped = False
        if self.split_wide_crops:
            new_crops, crop_map, remapped = split_crops(
                crops,  # type: ignore[arg-type]
                self.critical_ar,
                self.target_ar,
                self.dil_factor,
                True,
            )
            if remapped:
                crops = new_crops

        # Resize & batch them
        processed_batches = self.pre_processor(crops)  # type: ignore[arg-type]

        # Forward it
        raw = [self.model(batch, **kwargs)["preds"] for batch in processed_batches]

        # Process outputs
        out = [charseq for batch in raw for charseq in batch]

        # Remap crops
        if self.split_wide_crops and remapped:
            out = remap_preds(out, crop_map, self.dil_factor)

        return out
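For completeness, a sketch of wiring a RecognitionPredictor by hand with a stub model, so the crop-splitting and batching path can be exercised without downloading ONNX weights. The PreProcessor keyword names below (output_size, batch_size, mean, std) are assumptions inferred from how zoo.py calls it; its actual signature lives in onnxtr/models/preprocessor/base.py, which is not reproduced in this excerpt.

import numpy as np

from onnxtr.models.preprocessor import PreProcessor
from onnxtr.models.recognition.predictor import RecognitionPredictor


class DummyModel:
    """Stands in for a real recognizer: one (text, confidence) tuple per crop in the batch."""

    def __call__(self, batch: np.ndarray, **kwargs) -> dict:
        return {"preds": [("dummy", 0.99)] * batch.shape[0]}


pre_processor = PreProcessor(  # keyword names assumed, see note above
    output_size=(32, 128),
    batch_size=4,
    mean=(0.694, 0.695, 0.693),
    std=(0.299, 0.296, 0.301),
    preserve_aspect_ratio=True,
)
predictor = RecognitionPredictor(pre_processor, DummyModel())

crops = [(255 * np.random.rand(32, 100, 3)).astype(np.uint8) for _ in range(3)]
print(predictor(crops))  # three ('dummy', 0.99) tuples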
@@ -0,0 +1,89 @@ onnxtr/models/recognition/utils.py

# Copyright (C) 2021-2024, Mindee | Felix Dittrich.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import List

from rapidfuzz.distance import Levenshtein

__all__ = ["merge_strings", "merge_multi_strings"]


def merge_strings(a: str, b: str, dil_factor: float) -> str:
    """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.

    Args:
    ----
        a: first char seq, suffix should be similar to b's prefix.
        b: second char seq, prefix should be similar to a's suffix.
        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
            only used when the mother sequence is split on a character repetition

    Returns:
    -------
        A merged character sequence.

    Example::
        >>> from onnxtr.models.recognition.utils import merge_strings
        >>> merge_strings('abcd', 'cdefgh', 1.4)
        'abcdefgh'
        >>> merge_strings('abcdi', 'cdefgh', 1.4)
        'abcdefgh'
    """
    seq_len = min(len(a), len(b))
    if seq_len == 0:  # One sequence is empty, return the other
        return b if len(a) == 0 else a

    # Initialize merging index and corresponding score (mean Levenshtein)
    min_score, index = 1.0, 0  # No overlap, just concatenate

    scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]

    # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
    if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
        # Compute n_overlap (number of overlapping chars, geometrically determined)
        n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
        # Find the number of consecutive zeros in the scores list
        # Impossible to have a zero after a non-zero score in that case
        n_zeros = sum(val == 0 for val in scores)
        # Index is bounded by the geometrical overlap to avoid collapsing repetitions
        min_score, index = 0, min(n_zeros, n_overlap)

    else:  # Common case: choose the min score index
        for i, score in enumerate(scores):
            if score < min_score:
                min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char

    # Merge with correct overlap
    if index == 0:
        return a + b
    return a[:-1] + b[index - 1 :]


def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
    """Recursively merges consecutive string sequences with overlapping characters.

    Args:
    ----
        seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
            only used when the mother sequence is split on a character repetition

    Returns:
    -------
        A merged character sequence

    Example::
        >>> from onnxtr.models.recognition.utils import merge_multi_strings
        >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
        'abcdefghijkl'
    """

    def _recursive_merge(a: str, seq_list: List[str], dil_factor: float) -> str:
        # Recursive version of compute_overlap
        if len(seq_list) == 1:
            return merge_strings(a, seq_list[0], dil_factor)
        return _recursive_merge(merge_strings(a, seq_list[0], dil_factor), seq_list[1:], dil_factor)

    return _recursive_merge("", seq_list, dil_factor)
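The `(0, 0)` branch in merge_strings exists because a split that lands inside a character repetition yields several zero-cost overlaps, and the Levenshtein scores alone cannot tell how many repeated characters are genuinely shared; the dilation factor supplies a geometric bound. A sketch of that case, mirroring the scoring loop above (requires rapidfuzz; the example strings are invented):

from rapidfuzz.distance import Levenshtein

a, b, dil_factor = "wonderfull", "llest word", 1.4  # split happened inside the double "l"

seq_len = min(len(a), len(b))
scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
print(scores[:3])  # [0.0, 0.0, 0.66...]: two leading zeros, so the repetition edge case triggers

# Geometric overlap: with dilation 1.4, roughly (1.4 - 1) / 1.4 of b's width is redundant
n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
n_zeros = sum(val == 0 for val in scores)
index = min(n_zeros, n_overlap)  # bounded so the double "l" is not collapsed into one
print(a[:-1] + b[index - 1:])  # 'wonderfullest word'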
@@ -0,0 +1,69 @@ onnxtr/models/recognition/zoo.py

# Copyright (C) 2021-2024, Mindee | Felix Dittrich.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List

from onnxtr.models.preprocessor import PreProcessor

from .. import recognition
from .predictor import RecognitionPredictor

__all__ = ["recognition_predictor"]


ARCHS: List[str] = [
    "crnn_vgg16_bn",
    "crnn_mobilenet_v3_small",
    "crnn_mobilenet_v3_large",
    "sar_resnet31",
    "master",
    "vitstr_small",
    "vitstr_base",
    "parseq",
]


def _predictor(arch: Any, **kwargs: Any) -> RecognitionPredictor:
    if isinstance(arch, str):
        if arch not in ARCHS:
            raise ValueError(f"unknown architecture '{arch}'")

        _model = recognition.__dict__[arch]()
    else:
        if not isinstance(
            arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
        ):
            raise ValueError(f"unknown architecture: {type(arch)}")
        _model = arch

    kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
    kwargs["std"] = kwargs.get("std", _model.cfg["std"])
    kwargs["batch_size"] = kwargs.get("batch_size", 1024)
    input_shape = _model.cfg["input_shape"][1:]
    predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)

    return predictor


def recognition_predictor(arch: Any = "crnn_vgg16_bn", **kwargs: Any) -> RecognitionPredictor:
    """Text recognition architecture.

    Example::
        >>> import numpy as np
        >>> from onnxtr.models import recognition_predictor
        >>> model = recognition_predictor()
        >>> input_page = (255 * np.random.rand(32, 128, 3)).astype(np.uint8)
        >>> out = model([input_page])

    Args:
    ----
        arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
        **kwargs: optional parameters to be passed to the architecture

    Returns:
    -------
        Recognition predictor
    """
    return _predictor(arch, **kwargs)
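Finally, a hedged usage sketch for the zoo entry point; it assumes the wheel is installed and that, on first use, the chosen architecture's ONNX file can be fetched from the release URL recorded in its default_cfgs.

import numpy as np

from onnxtr.models import recognition_predictor

# By name: resolves "vitstr_small" through ARCHS above; extra kwargs (mean, std, batch_size)
# are forwarded to the PreProcessor built in _predictor.
predictor = recognition_predictor("vitstr_small", batch_size=256)

crops = [(255 * np.random.rand(32, 128, 3)).astype(np.uint8) for _ in range(2)]
print(predictor(crops))  # [(text, confidence), (text, confidence)]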