onnxtr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. onnxtr/__init__.py +2 -0
  2. onnxtr/contrib/__init__.py +0 -0
  3. onnxtr/contrib/artefacts.py +131 -0
  4. onnxtr/contrib/base.py +105 -0
  5. onnxtr/file_utils.py +33 -0
  6. onnxtr/io/__init__.py +5 -0
  7. onnxtr/io/elements.py +455 -0
  8. onnxtr/io/html.py +28 -0
  9. onnxtr/io/image.py +56 -0
  10. onnxtr/io/pdf.py +42 -0
  11. onnxtr/io/reader.py +85 -0
  12. onnxtr/models/__init__.py +4 -0
  13. onnxtr/models/_utils.py +141 -0
  14. onnxtr/models/builder.py +355 -0
  15. onnxtr/models/classification/__init__.py +2 -0
  16. onnxtr/models/classification/models/__init__.py +1 -0
  17. onnxtr/models/classification/models/mobilenet.py +120 -0
  18. onnxtr/models/classification/predictor/__init__.py +1 -0
  19. onnxtr/models/classification/predictor/base.py +57 -0
  20. onnxtr/models/classification/zoo.py +76 -0
  21. onnxtr/models/detection/__init__.py +2 -0
  22. onnxtr/models/detection/core.py +101 -0
  23. onnxtr/models/detection/models/__init__.py +3 -0
  24. onnxtr/models/detection/models/differentiable_binarization.py +159 -0
  25. onnxtr/models/detection/models/fast.py +160 -0
  26. onnxtr/models/detection/models/linknet.py +160 -0
  27. onnxtr/models/detection/postprocessor/__init__.py +0 -0
  28. onnxtr/models/detection/postprocessor/base.py +144 -0
  29. onnxtr/models/detection/predictor/__init__.py +1 -0
  30. onnxtr/models/detection/predictor/base.py +54 -0
  31. onnxtr/models/detection/zoo.py +73 -0
  32. onnxtr/models/engine.py +50 -0
  33. onnxtr/models/predictor/__init__.py +1 -0
  34. onnxtr/models/predictor/base.py +175 -0
  35. onnxtr/models/predictor/predictor.py +145 -0
  36. onnxtr/models/preprocessor/__init__.py +1 -0
  37. onnxtr/models/preprocessor/base.py +118 -0
  38. onnxtr/models/recognition/__init__.py +2 -0
  39. onnxtr/models/recognition/core.py +28 -0
  40. onnxtr/models/recognition/models/__init__.py +5 -0
  41. onnxtr/models/recognition/models/crnn.py +226 -0
  42. onnxtr/models/recognition/models/master.py +145 -0
  43. onnxtr/models/recognition/models/parseq.py +134 -0
  44. onnxtr/models/recognition/models/sar.py +134 -0
  45. onnxtr/models/recognition/models/vitstr.py +166 -0
  46. onnxtr/models/recognition/predictor/__init__.py +1 -0
  47. onnxtr/models/recognition/predictor/_utils.py +86 -0
  48. onnxtr/models/recognition/predictor/base.py +79 -0
  49. onnxtr/models/recognition/utils.py +89 -0
  50. onnxtr/models/recognition/zoo.py +69 -0
  51. onnxtr/models/zoo.py +114 -0
  52. onnxtr/transforms/__init__.py +1 -0
  53. onnxtr/transforms/base.py +112 -0
  54. onnxtr/utils/__init__.py +4 -0
  55. onnxtr/utils/common_types.py +18 -0
  56. onnxtr/utils/data.py +126 -0
  57. onnxtr/utils/fonts.py +41 -0
  58. onnxtr/utils/geometry.py +498 -0
  59. onnxtr/utils/multithreading.py +50 -0
  60. onnxtr/utils/reconstitution.py +70 -0
  61. onnxtr/utils/repr.py +64 -0
  62. onnxtr/utils/visualization.py +291 -0
  63. onnxtr/utils/vocabs.py +71 -0
  64. onnxtr/version.py +1 -0
  65. onnxtr-0.1.0.dist-info/LICENSE +201 -0
  66. onnxtr-0.1.0.dist-info/METADATA +481 -0
  67. onnxtr-0.1.0.dist-info/RECORD +70 -0
  68. onnxtr-0.1.0.dist-info/WHEEL +5 -0
  69. onnxtr-0.1.0.dist-info/top_level.txt +2 -0
  70. onnxtr-0.1.0.dist-info/zip-safe +1 -0
onnxtr/models/preprocessor/base.py
@@ -0,0 +1,118 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ import math
+ from typing import Any, List, Tuple, Union
+
+ import numpy as np
+
+ from onnxtr.transforms import Normalize, Resize
+ from onnxtr.utils.geometry import shape_translate
+ from onnxtr.utils.multithreading import multithread_exec
+ from onnxtr.utils.repr import NestedObject
+
+ __all__ = ["PreProcessor"]
+
+
+ class PreProcessor(NestedObject):
+     """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
+
+     Args:
+     ----
+         output_size: expected size of each page in format (H, W)
+         batch_size: the size of page batches
+         mean: mean value of the training distribution by channel
+         std: standard deviation of the training distribution by channel
+     """
+
+     _children_names: List[str] = ["resize", "normalize"]
+
+     def __init__(
+         self,
+         output_size: Tuple[int, int],
+         batch_size: int,
+         mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+         std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+         **kwargs: Any,
+     ) -> None:
+         self.batch_size = batch_size
+         self.resize = Resize(output_size, **kwargs)
+         self.normalize = Normalize(mean, std)
+
+     def batch_inputs(self, samples: List[np.ndarray]) -> List[np.ndarray]:
+         """Gather samples into batches for inference purposes
+
+         Args:
+         ----
+             samples: list of samples (np.ndarray)
+
+         Returns:
+         -------
+             list of batched samples
+         """
+         num_batches = int(math.ceil(len(samples) / self.batch_size))
+         batches = [
+             np.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0)
+             for idx in range(num_batches)
+         ]
+
+         return batches
+
+     def sample_transforms(self, x: np.ndarray) -> np.ndarray:
+         if x.ndim != 3:
+             raise AssertionError("expected a 3D array")
+         if isinstance(x, np.ndarray):
+             if x.dtype not in (np.uint8, np.float32):
+                 raise TypeError("unsupported data type for numpy.ndarray")
+             x = shape_translate(x, "HWC")
+         # Data type & 255 division
+         if x.dtype == np.uint8:
+             x = x.astype(np.float32) / 255.0
+         # Resizing
+         x = self.resize(x)
+
+         return x
+
+     def __call__(self, x: Union[np.ndarray, List[np.ndarray]]) -> List[np.ndarray]:
+         """Prepare document data for model forwarding
+
+         Args:
+         ----
+             x: list of images (np.ndarray) or a 4D array (already resized and batched)
+
+         Returns:
+         -------
+             list of page batches
+         """
+         # Input type check
+         if isinstance(x, np.ndarray):
+             if x.ndim != 4:
+                 raise AssertionError("expected a 4D array")
+             if isinstance(x, np.ndarray):
+                 if x.dtype not in (np.uint8, np.float32):
+                     raise TypeError("unsupported data type for numpy.ndarray")
+                 x = shape_translate(x, "BHWC")
+
+             # Data type & 255 division
+             if x.dtype == np.uint8:
+                 x = x.astype(np.float32) / 255.0
+             # Resizing
+             if (x.shape[1], x.shape[2]) != self.resize.output_size:
+                 x = np.array([self.resize(sample) for sample in x])
+
+             batches = [x]
+
+         elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
+             # Sample transform (to tensor, resize)
+             samples = list(multithread_exec(self.sample_transforms, x))
+             # Batching
+             batches = self.batch_inputs(samples)
+         else:
+             raise TypeError(f"invalid input type: {type(x)}")
+
+         # Batch transforms (normalize)
+         batches = list(multithread_exec(self.normalize, batches))
+
+         return batches
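
This hunk turns an arbitrary mix of page sizes into fixed-shape, normalized batches: single pages go through `sample_transforms` in parallel, then `batch_inputs` stacks them in groups of `batch_size`. A minimal usage sketch, assuming `PreProcessor` is re-exported from `onnxtr.models.preprocessor` and that pages stay in HWC layout; sizes and values are purely illustrative:

    import numpy as np

    from onnxtr.models.preprocessor import PreProcessor

    processor = PreProcessor(output_size=(1024, 1024), batch_size=2)

    # Three pages of different sizes -> ceil(3 / 2) = 2 batches (2 pages + 1 page)
    pages = [
        np.random.randint(0, 255, size=(h, w, 3), dtype=np.uint8)
        for h, w in [(800, 600), (1200, 900), (640, 480)]
    ]
    batches = processor(pages)
    print([batch.shape for batch in batches])  # [(2, 1024, 1024, 3), (1, 1024, 1024, 3)]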
onnxtr/models/recognition/__init__.py
@@ -0,0 +1,2 @@
+ from .models import *
+ from .zoo import *
onnxtr/models/recognition/core.py
@@ -0,0 +1,28 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+ from onnxtr.utils.repr import NestedObject
+
+ __all__ = ["RecognitionPostProcessor"]
+
+
+ class RecognitionPostProcessor(NestedObject):
+     """Abstract class to postprocess the raw output of the model
+
+     Args:
+     ----
+         vocab: string containing the ordered sequence of supported characters
+     """
+
+     def __init__(
+         self,
+         vocab: str,
+     ) -> None:
+         self.vocab = vocab
+         self._embedding = list(self.vocab) + ["<eos>"]
+
+     def extra_repr(self) -> str:
+         return f"vocab_size={len(self.vocab)}"
onnxtr/models/recognition/models/__init__.py
@@ -0,0 +1,5 @@
+ from .crnn import *
+ from .sar import *
+ from .master import *
+ from .vitstr import *
+ from .parseq import *
onnxtr/models/recognition/models/crnn.py
@@ -0,0 +1,226 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from copy import deepcopy
+ from itertools import groupby
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ from scipy.special import softmax
+
+ from onnxtr.utils import VOCABS
+
+ from ...engine import Engine
+ from ..core import RecognitionPostProcessor
+
+ __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "crnn_vgg16_bn": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 32, 128),
+         "vocab": VOCABS["legacy_french"],
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_vgg16_bn-662979cc.onnx",
+     },
+     "crnn_mobilenet_v3_small": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 32, 128),
+         "vocab": VOCABS["french"],
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_mobilenet_v3_small-bded4d49.onnx",
+     },
+     "crnn_mobilenet_v3_large": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 32, 128),
+         "vocab": VOCABS["french"],
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_mobilenet_v3_large-d42e8185.onnx",
+     },
+ }
+
+
+ class CRNNPostProcessor(RecognitionPostProcessor):
+     """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
+
+     Args:
+     ----
+         vocab: string containing the ordered sequence of supported characters
+     """
+
+     def __init__(self, vocab):
+         self.vocab = vocab
+
+     def decode_sequence(self, sequence, vocab):
+         return "".join([vocab[int(char)] for char in sequence])
+
+     def ctc_best_path(
+         self,
+         logits,
+         vocab,
+         blank=0,
+     ):
+         """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired by
+         `CTCDecoder <https://github.com/githubharald/CTCDecoder>`_.
+
+         Args:
+         ----
+             logits: model output, shape: N x T x C
+             vocab: vocabulary to use
+             blank: index of blank label
+
+         Returns:
+         -------
+             A list of tuples: (word, confidence)
+         """
+         # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
+         probs = softmax(logits, axis=-1).max(axis=-1).min(axis=1)
+
+         # Collapse best path (using itertools.groupby), map to chars, join char list to string
+         words = [
+             self.decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
+             for seq in np.argmax(logits, axis=-1)
+         ]
+
+         return list(zip(words, probs.astype(float).tolist()))
+
+     def __call__(self, logits):
+         """Performs CTC decoding of the raw model output, mapping each predicted
+         label index to its character via the vocabulary
+
+         Args:
+         ----
+             logits: raw output of the model, shape (N, seq_len, vocab_size + 1)
+
+         Returns:
+         -------
+             A list of tuples: (word, confidence)
+
+         """
+         # Decode CTC
+         return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
+
+
+ class CRNN(Engine):
+     """CRNN Onnx loader
+
+     Args:
+     ----
+         model_path: path or url to onnx model file
+         vocab: vocabulary used for encoding
+         cfg: configuration dictionary
+         **kwargs: additional arguments to be passed to `Engine`
+     """
+
+     _children_names: List[str] = ["postprocessor"]
+
+     def __init__(
+         self,
+         model_path: str,
+         vocab: str,
+         cfg: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(url=model_path, **kwargs)
+         self.vocab = vocab
+         self.cfg = cfg
+         self.postprocessor = CRNNPostProcessor(self.vocab)
+
+     def __call__(
+         self,
+         x: np.ndarray,
+         return_model_output: bool = False,
+     ) -> Dict[str, Any]:
+         logits = self.run(x)
+
+         out: Dict[str, Any] = {}
+         if return_model_output:
+             out["out_map"] = logits
+
+         # Post-process
+         out["preds"] = self.postprocessor(logits)
+
+         return out
+
+
+ def _crnn(
+     arch: str,
+     model_path: str,
+     **kwargs: Any,
+ ) -> CRNN:
+     kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
+
+     _cfg = deepcopy(default_cfgs[arch])
+     _cfg["vocab"] = kwargs["vocab"]
+     _cfg["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+
+     # Build the model
+     return CRNN(model_path, cfg=_cfg, **kwargs)
+
+
+ def crnn_vgg16_bn(model_path: str = default_cfgs["crnn_vgg16_bn"]["url"], **kwargs: Any) -> CRNN:
+     """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+     Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import crnn_vgg16_bn
+     >>> model = crnn_vgg16_bn()
+     >>> input_tensor = np.random.rand(1, 3, 32, 128)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the CRNN architecture
+
+     Returns:
+     -------
+         text recognition architecture
+     """
+     return _crnn("crnn_vgg16_bn", model_path, **kwargs)
+
+
+ def crnn_mobilenet_v3_small(model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"], **kwargs: Any) -> CRNN:
+     """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+     Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import crnn_mobilenet_v3_small
+     >>> model = crnn_mobilenet_v3_small()
+     >>> input_tensor = np.random.rand(1, 3, 32, 128)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the CRNN architecture
+
+     Returns:
+     -------
+         text recognition architecture
+     """
+     return _crnn("crnn_mobilenet_v3_small", model_path, **kwargs)
+
+
+ def crnn_mobilenet_v3_large(model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"], **kwargs: Any) -> CRNN:
+     """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+     Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import crnn_mobilenet_v3_large
+     >>> model = crnn_mobilenet_v3_large()
+     >>> input_tensor = np.random.rand(1, 3, 32, 128)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the CRNN architecture
+
+     Returns:
+     -------
+         text recognition architecture
+     """
+     return _crnn("crnn_mobilenet_v3_large", model_path, **kwargs)
onnxtr/models/recognition/models/master.py
@@ -0,0 +1,145 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from copy import deepcopy
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ from scipy.special import softmax
+
+ from onnxtr.utils import VOCABS
+
+ from ...engine import Engine
+ from ..core import RecognitionPostProcessor
+
+ __all__ = ["MASTER", "master"]
+
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "master": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 32, 128),
+         "vocab": VOCABS["french"],
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/master-b1287fcd.onnx",
+     },
+ }
+
+
+ class MASTER(Engine):
+     """MASTER Onnx loader
+
+     Args:
+     ----
+         model_path: path or url to onnx model file
+         vocab: vocabulary (without EOS, SOS, PAD)
+         cfg: dictionary containing information about the model
+         **kwargs: additional arguments to be passed to `Engine`
+     """
+
+     def __init__(
+         self,
+         model_path: str,
+         vocab: str,
+         cfg: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(url=model_path, **kwargs)
+
+         self.vocab = vocab
+         self.cfg = cfg
+         self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
+
+     def __call__(
+         self,
+         x: np.ndarray,
+         return_model_output: bool = False,
+     ) -> Dict[str, Any]:
+         """Call function
+
+         Args:
+         ----
+             x: images
+             return_model_output: if True, return logits
+
+         Returns:
+         -------
+             A dictionary containing the predictions and, optionally, the raw logits.
+         """
+         logits = self.run(x)
+         out: Dict[str, Any] = {}
+
+         if return_model_output:
+             out["out_map"] = logits
+
+         out["preds"] = self.postprocessor(logits)
+
+         return out
+
+
+ class MASTERPostProcessor(RecognitionPostProcessor):
+     """Post-processor for the MASTER model
+
+     Args:
+     ----
+         vocab: string containing the ordered sequence of supported characters
+     """
+
+     def __init__(
+         self,
+         vocab: str,
+     ) -> None:
+         super().__init__(vocab)
+         self._embedding = list(vocab) + ["<eos>", "<sos>", "<pad>"]
+
+     def __call__(self, logits: np.ndarray) -> List[Tuple[str, float]]:
+         # Compute predictions with argmax, as done for attention models
+         out_idxs = np.argmax(logits, axis=-1)
+         # N x L
+         probs = np.take_along_axis(softmax(logits, axis=-1), out_idxs[..., None], axis=-1).squeeze(-1)
+         # Take the minimum confidence of the sequence
+         probs = np.min(probs, axis=1)
+
+         word_values = [
+             "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0] for encoded_seq in out_idxs
+         ]
+
+         return list(zip(word_values, np.clip(probs, 0, 1).astype(float).tolist()))
+
+
+ def _master(
+     arch: str,
+     model_path: str,
+     **kwargs: Any,
+ ) -> MASTER:
+     # Patch the config
+     _cfg = deepcopy(default_cfgs[arch])
+     _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+     _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+
+     kwargs["vocab"] = _cfg["vocab"]
+
+     return MASTER(model_path, cfg=_cfg, **kwargs)
+
+
+ def master(model_path: str = default_cfgs["master"]["url"], **kwargs: Any) -> MASTER:
+     """MASTER as described in `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition" <https://arxiv.org/pdf/1910.02562.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import master
+     >>> model = master()
+     >>> input_tensor = np.random.rand(1, 3, 32, 128)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments passed to the MASTER architecture
+
+     Returns:
+     -------
+         text recognition architecture
+     """
+     return _master("master", model_path, **kwargs)
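
`MASTERPostProcessor` scores a word by the weakest step of its argmax path: it gathers each chosen character's probability with `np.take_along_axis`, then keeps the per-sequence minimum. A small self-contained numpy sketch with illustrative shapes:

    import numpy as np
    from scipy.special import softmax

    logits = np.random.randn(2, 5, 10)  # (batch, seq_len, vocab_size + special tokens)
    out_idxs = np.argmax(logits, axis=-1)  # (2, 5) chosen character indices
    # Probability of each chosen character, gathered along the class axis
    probs = np.take_along_axis(softmax(logits, axis=-1), out_idxs[..., None], axis=-1).squeeze(-1)
    print(np.min(probs, axis=1))  # one confidence per sequence: its weakest step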
onnxtr/models/recognition/models/parseq.py
@@ -0,0 +1,134 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from copy import deepcopy
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ from scipy.special import softmax
+
+ from onnxtr.utils import VOCABS
+
+ from ...engine import Engine
+ from ..core import RecognitionPostProcessor
+
+ __all__ = ["PARSeq", "parseq"]
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "parseq": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 32, 128),
+         "vocab": VOCABS["french"],
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/parseq-00b40714.onnx",
+     },
+ }
+
+
+ class PARSeq(Engine):
+     """PARSeq Onnx loader
+
+     Args:
+     ----
+         vocab: vocabulary used for encoding
+         cfg: dictionary containing information about the model
+         **kwargs: additional arguments to be passed to `Engine`
+     """
+
+     def __init__(
+         self,
+         model_path: str,
+         vocab: str,
+         cfg: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(url=model_path, **kwargs)
+         self.vocab = vocab
+         self.cfg = cfg
+         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
+
+     def __call__(
+         self,
+         x: np.ndarray,
+         return_model_output: bool = False,
+     ) -> Dict[str, Any]:
+         logits = self.run(x)
+         out: Dict[str, Any] = {}
+
+         if return_model_output:
+             out["out_map"] = logits
+
+         out["preds"] = self.postprocessor(logits)
+         return out
+
+
+ class PARSeqPostProcessor(RecognitionPostProcessor):
+     """Post processor for PARSeq architecture
+
+     Args:
+     ----
+         vocab: string containing the ordered sequence of supported characters
+     """
+
+     def __init__(
+         self,
+         vocab: str,
+     ) -> None:
+         super().__init__(vocab)
+         self._embedding = list(vocab) + ["<eos>", "<sos>", "<pad>"]
+
+     def __call__(self, logits):
+         # Compute predictions with argmax, as done for attention models
+         out_idxs = np.argmax(logits, axis=-1)
+         preds_prob = softmax(logits, axis=-1).max(axis=-1)
+
+         word_values = [
+             "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0] for encoded_seq in out_idxs
+         ]
+         # Compute the probability of each word, up to the EOS token
+         probs = [
+             preds_prob[i, : len(word)].clip(0, 1).mean().astype(float) if word else 0.0
+             for i, word in enumerate(word_values)
+         ]
+
+         return list(zip(word_values, probs))
+
+
+ def _parseq(
+     arch: str,
+     model_path: str,
+     **kwargs: Any,
+ ) -> PARSeq:
+     # Patch the config
+     _cfg = deepcopy(default_cfgs[arch])
+     _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+     _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+
+     kwargs["vocab"] = _cfg["vocab"]
+
+     # Build the model
+     return PARSeq(model_path, cfg=_cfg, **kwargs)
+
+
+ def parseq(model_path: str = default_cfgs["parseq"]["url"], **kwargs: Any) -> PARSeq:
+     """PARSeq architecture from
+     `"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import parseq
+     >>> model = parseq()
+     >>> input_tensor = np.random.rand(1, 3, 32, 128)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the PARSeq architecture
+
+     Returns:
+     -------
+         text recognition architecture
+     """
+     return _parseq("parseq", model_path, **kwargs)
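
All three recognition files follow the same zoo pattern: a private `_arch` helper patches the default config with any `vocab` or `input_shape` overrides before building the loader. A usage sketch, assuming `VOCABS` exposes an "english" entry alongside "french" and that the default ONNX weights can be downloaded:

    import numpy as np

    from onnxtr.models import parseq
    from onnxtr.utils import VOCABS

    # Override the default French vocab; _parseq copies it into the config
    model = parseq(vocab=VOCABS["english"])

    input_tensor = np.random.rand(1, 3, 32, 128).astype(np.float32)
    out = model(input_tensor, return_model_output=True)
    print(out["preds"], out["out_map"].shape)  # [(word, confidence)], raw logits shape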