onnxtr-0.1.0-py3-none-any.whl

This diff shows the content of a publicly released package version as it appears in its public registry, and is provided for informational purposes only.
Files changed (70)
  1. onnxtr/__init__.py +2 -0
  2. onnxtr/contrib/__init__.py +0 -0
  3. onnxtr/contrib/artefacts.py +131 -0
  4. onnxtr/contrib/base.py +105 -0
  5. onnxtr/file_utils.py +33 -0
  6. onnxtr/io/__init__.py +5 -0
  7. onnxtr/io/elements.py +455 -0
  8. onnxtr/io/html.py +28 -0
  9. onnxtr/io/image.py +56 -0
  10. onnxtr/io/pdf.py +42 -0
  11. onnxtr/io/reader.py +85 -0
  12. onnxtr/models/__init__.py +4 -0
  13. onnxtr/models/_utils.py +141 -0
  14. onnxtr/models/builder.py +355 -0
  15. onnxtr/models/classification/__init__.py +2 -0
  16. onnxtr/models/classification/models/__init__.py +1 -0
  17. onnxtr/models/classification/models/mobilenet.py +120 -0
  18. onnxtr/models/classification/predictor/__init__.py +1 -0
  19. onnxtr/models/classification/predictor/base.py +57 -0
  20. onnxtr/models/classification/zoo.py +76 -0
  21. onnxtr/models/detection/__init__.py +2 -0
  22. onnxtr/models/detection/core.py +101 -0
  23. onnxtr/models/detection/models/__init__.py +3 -0
  24. onnxtr/models/detection/models/differentiable_binarization.py +159 -0
  25. onnxtr/models/detection/models/fast.py +160 -0
  26. onnxtr/models/detection/models/linknet.py +160 -0
  27. onnxtr/models/detection/postprocessor/__init__.py +0 -0
  28. onnxtr/models/detection/postprocessor/base.py +144 -0
  29. onnxtr/models/detection/predictor/__init__.py +1 -0
  30. onnxtr/models/detection/predictor/base.py +54 -0
  31. onnxtr/models/detection/zoo.py +73 -0
  32. onnxtr/models/engine.py +50 -0
  33. onnxtr/models/predictor/__init__.py +1 -0
  34. onnxtr/models/predictor/base.py +175 -0
  35. onnxtr/models/predictor/predictor.py +145 -0
  36. onnxtr/models/preprocessor/__init__.py +1 -0
  37. onnxtr/models/preprocessor/base.py +118 -0
  38. onnxtr/models/recognition/__init__.py +2 -0
  39. onnxtr/models/recognition/core.py +28 -0
  40. onnxtr/models/recognition/models/__init__.py +5 -0
  41. onnxtr/models/recognition/models/crnn.py +226 -0
  42. onnxtr/models/recognition/models/master.py +145 -0
  43. onnxtr/models/recognition/models/parseq.py +134 -0
  44. onnxtr/models/recognition/models/sar.py +134 -0
  45. onnxtr/models/recognition/models/vitstr.py +166 -0
  46. onnxtr/models/recognition/predictor/__init__.py +1 -0
  47. onnxtr/models/recognition/predictor/_utils.py +86 -0
  48. onnxtr/models/recognition/predictor/base.py +79 -0
  49. onnxtr/models/recognition/utils.py +89 -0
  50. onnxtr/models/recognition/zoo.py +69 -0
  51. onnxtr/models/zoo.py +114 -0
  52. onnxtr/transforms/__init__.py +1 -0
  53. onnxtr/transforms/base.py +112 -0
  54. onnxtr/utils/__init__.py +4 -0
  55. onnxtr/utils/common_types.py +18 -0
  56. onnxtr/utils/data.py +126 -0
  57. onnxtr/utils/fonts.py +41 -0
  58. onnxtr/utils/geometry.py +498 -0
  59. onnxtr/utils/multithreading.py +50 -0
  60. onnxtr/utils/reconstitution.py +70 -0
  61. onnxtr/utils/repr.py +64 -0
  62. onnxtr/utils/visualization.py +291 -0
  63. onnxtr/utils/vocabs.py +71 -0
  64. onnxtr/version.py +1 -0
  65. onnxtr-0.1.0.dist-info/LICENSE +201 -0
  66. onnxtr-0.1.0.dist-info/METADATA +481 -0
  67. onnxtr-0.1.0.dist-info/RECORD +70 -0
  68. onnxtr-0.1.0.dist-info/WHEEL +5 -0
  69. onnxtr-0.1.0.dist-info/top_level.txt +2 -0
  70. onnxtr-0.1.0.dist-info/zip-safe +1 -0
onnxtr/models/recognition/models/sar.py
@@ -0,0 +1,134 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from copy import deepcopy
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ from scipy.special import softmax
+
+ from onnxtr.utils import VOCABS
+
+ from ...engine import Engine
+ from ..core import RecognitionPostProcessor
+
+ __all__ = ["SAR", "sar_resnet31"]
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "sar_resnet31": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 32, 128),
+         "vocab": VOCABS["french"],
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/sar_resnet31-395f8005.onnx",
+     },
+ }
+
+
+ class SAR(Engine):
+     """SAR Onnx loader
+
+     Args:
+     ----
+         model_path: path to onnx model file
+         vocab: vocabulary used for encoding
+         cfg: dictionary containing information about the model
+         **kwargs: additional arguments to be passed to `Engine`
+     """
+
+     def __init__(
+         self,
+         model_path: str,
+         vocab: str,
+         cfg: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(url=model_path, **kwargs)
+         self.vocab = vocab
+         self.cfg = cfg
+         self.postprocessor = SARPostProcessor(self.vocab)
+
+     def __call__(
+         self,
+         x: np.ndarray,
+         return_model_output: bool = False,
+     ) -> Dict[str, Any]:
+         logits = self.run(x)
+
+         out: Dict[str, Any] = {}
+         if return_model_output:
+             out["out_map"] = logits
+
+         out["preds"] = self.postprocessor(logits)
+
+         return out
+
+
+ class SARPostProcessor(RecognitionPostProcessor):
+     """Post processor for SAR architectures
+
+     Args:
+     ----
+         vocab: string containing the ordered sequence of supported characters
+     """
+
+     def __init__(
+         self,
+         vocab: str,
+     ) -> None:
+         super().__init__(vocab)
+         self._embedding = list(self.vocab) + ["<eos>"]
+
+     def __call__(self, logits):
+         # compute pred with argmax for attention models
+         out_idxs = np.argmax(logits, axis=-1)
+         # N x L
+         probs = np.take_along_axis(softmax(logits, axis=-1), out_idxs[..., None], axis=-1).squeeze(-1)
+         # Take the minimum confidence of the sequence
+         probs = np.min(probs, axis=1)
+
+         word_values = [
+             "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0] for encoded_seq in out_idxs
+         ]
+
+         return list(zip(word_values, np.clip(probs, 0, 1).astype(float).tolist()))
+
+
+ def _sar(
+     arch: str,
+     model_path: str,
+     **kwargs: Any,
+ ) -> SAR:
+     # Patch the config
+     _cfg = deepcopy(default_cfgs[arch])
+     _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+     _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+
+     kwargs["vocab"] = _cfg["vocab"]
+
+     # Build the model
+     return SAR(model_path, cfg=_cfg, **kwargs)
+
+
+ def sar_resnet31(model_path: str = default_cfgs["sar_resnet31"]["url"], **kwargs: Any) -> SAR:
+     """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read: A Simple and Strong
+     Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import sar_resnet31
+     >>> model = sar_resnet31()
+     >>> input_tensor = np.random.rand(1, 3, 32, 128)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         **kwargs: keyword arguments of the SAR architecture
+
+     Returns:
+     -------
+         text recognition architecture
+     """
+     return _sar("sar_resnet31", model_path, **kwargs)
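To make the decoding step above concrete, here is a minimal sketch of what SARPostProcessor does, using a toy vocabulary and invented logits (only for illustration; real models use VOCABS["french"]): argmax over the class axis, cut the sequence at the first <eos>, and keep the weakest per-step confidence as the word score.

import numpy as np
from scipy.special import softmax

vocab = "abc"                          # toy vocabulary, 3 characters
embedding = list(vocab) + ["<eos>"]

# fake logits: batch of 1, sequence length 4, 4 classes (a, b, c, <eos>)
logits = np.array([[[9.0, 1.0, 0.0, 0.0],   # -> 'a'
                    [0.0, 8.0, 1.0, 0.0],   # -> 'b'
                    [0.0, 0.0, 0.5, 7.0],   # -> '<eos>' (decoding stops here)
                    [0.0, 0.0, 0.0, 9.0]]])

out_idxs = np.argmax(logits, axis=-1)
probs = np.take_along_axis(softmax(logits, axis=-1), out_idxs[..., None], axis=-1).squeeze(-1)
word = "".join(embedding[idx] for idx in out_idxs[0]).split("<eos>")[0]
print(word, float(np.min(probs, axis=1)[0]))  # 'ab' plus the minimum per-step confidence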
onnxtr/models/recognition/models/vitstr.py
@@ -0,0 +1,166 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from copy import deepcopy
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ from scipy.special import softmax
+
+ from onnxtr.utils import VOCABS
+
+ from ...engine import Engine
+ from ..core import RecognitionPostProcessor
+
+ __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "vitstr_small": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 32, 128),
+         "vocab": VOCABS["french"],
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/vitstr_small-3ff9c500.onnx",
+     },
+     "vitstr_base": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 32, 128),
+         "vocab": VOCABS["french"],
+         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/vitstr_base-ff62f5be.onnx",
+     },
+ }
+
+
+ class ViTSTR(Engine):
+     """ViTSTR Onnx loader
+
+     Args:
+     ----
+         model_path: path to onnx model file
+         vocab: vocabulary used for encoding
+         cfg: dictionary containing information about the model
+         **kwargs: additional arguments to be passed to `Engine`
+     """
+
+     def __init__(
+         self,
+         model_path: str,
+         vocab: str,
+         cfg: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(url=model_path, **kwargs)
+         self.vocab = vocab
+         self.cfg = cfg
+
+         self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
+
+     def __call__(
+         self,
+         x: np.ndarray,
+         return_model_output: bool = False,
+     ) -> Dict[str, Any]:
+         logits = self.run(x)
+
+         out: Dict[str, Any] = {}
+         if return_model_output:
+             out["out_map"] = logits
+
+         out["preds"] = self.postprocessor(logits)
+
+         return out
+
+
+ class ViTSTRPostProcessor(RecognitionPostProcessor):
+     """Post processor for ViTSTR architecture
+
+     Args:
+     ----
+         vocab: string containing the ordered sequence of supported characters
+     """
+
+     def __init__(
+         self,
+         vocab: str,
+     ) -> None:
+         super().__init__(vocab)
+         self._embedding = list(vocab) + ["<eos>", "<sos>"]
+
+     def __call__(self, logits):
+         # compute pred with argmax for attention models
+         out_idxs = np.argmax(logits, axis=-1)
+         preds_prob = softmax(logits, axis=-1).max(axis=-1)
+
+         word_values = [
+             "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0] for encoded_seq in out_idxs
+         ]
+         # compute probabilities for each word up to the EOS token
+         probs = [
+             preds_prob[i, : len(word)].clip(0, 1).mean().astype(float) if word else 0.0
+             for i, word in enumerate(word_values)
+         ]
+
+         return list(zip(word_values, probs))
+
+
+ def _vitstr(
+     arch: str,
+     model_path: str,
+     **kwargs: Any,
+ ) -> ViTSTR:
+     # Patch the config
+     _cfg = deepcopy(default_cfgs[arch])
+     _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+     _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+
+     kwargs["vocab"] = _cfg["vocab"]
+
+     # Build the model
+     return ViTSTR(model_path, cfg=_cfg, **kwargs)
+
+
+ def vitstr_small(model_path: str = default_cfgs["vitstr_small"]["url"], **kwargs: Any) -> ViTSTR:
+     """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
+     <https://arxiv.org/pdf/2105.08582.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import vitstr_small
+     >>> model = vitstr_small()
+     >>> input_tensor = np.random.rand(1, 3, 32, 128)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         kwargs: keyword arguments of the ViTSTR architecture
+
+     Returns:
+     -------
+         text recognition architecture
+     """
+     return _vitstr("vitstr_small", model_path, **kwargs)
+
+
+ def vitstr_base(model_path: str = default_cfgs["vitstr_base"]["url"], **kwargs: Any) -> ViTSTR:
+     """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
+     <https://arxiv.org/pdf/2105.08582.pdf>`_.
+
+     >>> import numpy as np
+     >>> from onnxtr.models import vitstr_base
+     >>> model = vitstr_base()
+     >>> input_tensor = np.random.rand(1, 3, 32, 128)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         model_path: path to onnx model file, defaults to url in default_cfgs
+         kwargs: keyword arguments of the ViTSTR architecture
+
+     Returns:
+     -------
+         text recognition architecture
+     """
+     return _vitstr("vitstr_base", model_path, **kwargs)
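The two post-processors aggregate confidence differently: SAR keeps the minimum per-step probability over the whole sequence, while ViTSTR averages the per-character maxima over the decoded word length only. A rough sketch of the difference on made-up per-step probabilities:

import numpy as np

# per-step max softmax probabilities for one decoded word "ab" followed by padding (illustrative values)
preds_prob = np.array([[0.99, 0.95, 0.60, 0.10]])
word = "ab"

sar_score = float(np.min(preds_prob, axis=1)[0])                     # 0.10 -> whole-sequence minimum
vitstr_score = float(preds_prob[0, : len(word)].clip(0, 1).mean())   # 0.97 -> mean over the word only
print(sar_score, vitstr_score)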
onnxtr/models/recognition/predictor/__init__.py
@@ -0,0 +1 @@
+ from .base import *
onnxtr/models/recognition/predictor/_utils.py
@@ -0,0 +1,86 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import List, Tuple, Union
+
+ import numpy as np
+
+ from ..utils import merge_multi_strings
+
+ __all__ = ["split_crops", "remap_preds"]
+
+
+ def split_crops(
+     crops: List[np.ndarray],
+     max_ratio: float,
+     target_ratio: int,
+     dilation: float,
+     channels_last: bool = True,
+ ) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
+     """Chunk crops horizontally to match a given aspect ratio
+
+     Args:
+     ----
+         crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
+         max_ratio: the maximum aspect ratio that won't trigger the chunk
+         target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
+         dilation: the width dilation of final chunks (to provide some overlaps)
+         channels_last: whether the numpy array has dimensions in channels last order
+
+     Returns:
+     -------
+         a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
+     """
+     _remap_required = False
+     crop_map: List[Union[int, Tuple[int, int]]] = []
+     new_crops: List[np.ndarray] = []
+     for crop in crops:
+         h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
+         aspect_ratio = w / h
+         if aspect_ratio > max_ratio:
+             # Determine the number of crops, reference aspect ratio = 4 = 128 / 32
+             num_subcrops = int(aspect_ratio // target_ratio)
+             # Find the new widths, additional dilation factor to overlap crops
+             width = dilation * w / num_subcrops
+             centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
+             # Get the crops
+             if channels_last:
+                 _crops = [
+                     crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :]
+                     for center in centers
+                 ]
+             else:
+                 _crops = [
+                     crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2)))]
+                     for center in centers
+                 ]
+             # Avoid sending zero-sized crops
+             _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
+             # Record the slice of crops
+             crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
+             new_crops.extend(_crops)
+             # At least one crop will require merging
+             _remap_required = True
+         else:
+             crop_map.append(len(new_crops))
+             new_crops.append(crop)
+
+     return new_crops, crop_map, _remap_required
+
+
+ def remap_preds(
+     preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
+ ) -> List[Tuple[str, float]]:
+     remapped_out = []
+     for _idx in crop_map:
+         # Crop hasn't been split
+         if isinstance(_idx, int):
+             remapped_out.append(preds[_idx])
+         else:
+             # unzip
+             vals, probs = zip(*preds[_idx[0] : _idx[1]])
+             # Merge the string values
+             remapped_out.append((merge_multi_strings(vals, dilation), min(probs)))  # type: ignore[arg-type]
+     return remapped_out
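A quick sketch of these two helpers working together, assuming the wheel is installed; the expected output shown in the comments is derived by hand from the code above. With max_ratio=8, target_ratio=6 and dilation=1.4 (the values used by RecognitionPredictor below), a 32x416 crop has an aspect ratio of 13, so it is cut into two overlapping sub-crops and crop_map records the slice that covers them, while a 32x128 crop is passed through untouched.

import numpy as np
from onnxtr.models.recognition.predictor._utils import split_crops, remap_preds

wide = np.zeros((32, 416, 3), dtype=np.uint8)      # aspect ratio 13 > 8 -> chunked
narrow = np.zeros((32, 128, 3), dtype=np.uint8)    # aspect ratio 4 -> kept as-is

new_crops, crop_map, remapped = split_crops([wide, narrow], 8, 6, 1.4, channels_last=True)
print(len(new_crops), crop_map, remapped)
# 3 [(0, 2), 2] True  -> the wide crop became sub-crops 0..1, the narrow crop is index 2

# after recognition, per-chunk predictions are merged back with remap_preds
preds = [("abcd", 0.9), ("cdef", 0.8), ("xy", 0.99)]
print(remap_preds(preds, crop_map, 1.4))
# [('abcdef', 0.8), ('xy', 0.99)]  -> merged string, minimum confidence of the chunks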
onnxtr/models/recognition/predictor/base.py
@@ -0,0 +1,79 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, List, Sequence, Tuple
+
+ import numpy as np
+
+ from onnxtr.models.preprocessor import PreProcessor
+ from onnxtr.utils.repr import NestedObject
+
+ from ._utils import remap_preds, split_crops
+
+ __all__ = ["RecognitionPredictor"]
+
+
+ class RecognitionPredictor(NestedObject):
+     """Implements an object able to identify character sequences in images
+
+     Args:
+     ----
+         pre_processor: transform inputs for easier batched model inference
+         model: core recognition architecture
+         split_wide_crops: whether to use crop splitting for high aspect ratio crops
+     """
+
+     def __init__(
+         self,
+         pre_processor: PreProcessor,
+         model: Any,
+         split_wide_crops: bool = True,
+     ) -> None:
+         super().__init__()
+         self.pre_processor = pre_processor
+         self.model = model
+         self.split_wide_crops = split_wide_crops
+         self.critical_ar = 8  # Critical aspect ratio
+         self.dil_factor = 1.4  # Dilation factor to overlap the crops
+         self.target_ar = 6  # Target aspect ratio
+
+     def __call__(
+         self,
+         crops: Sequence[np.ndarray],
+         **kwargs: Any,
+     ) -> List[Tuple[str, float]]:
+         if len(crops) == 0:
+             return []
+         # Dimension check
+         if any(crop.ndim != 3 for crop in crops):
+             raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
+
+         # Split crops that are too wide
+         remapped = False
+         if self.split_wide_crops:
+             new_crops, crop_map, remapped = split_crops(
+                 crops,  # type: ignore[arg-type]
+                 self.critical_ar,
+                 self.target_ar,
+                 self.dil_factor,
+                 True,
+             )
+             if remapped:
+                 crops = new_crops
+
+         # Resize & batch them
+         processed_batches = self.pre_processor(crops)  # type: ignore[arg-type]
+
+         # Forward it
+         raw = [self.model(batch, **kwargs)["preds"] for batch in processed_batches]
+
+         # Process outputs
+         out = [charseq for batch in raw for charseq in batch]
+
+         # Remap crops
+         if self.split_wide_crops and remapped:
+             out = remap_preds(out, crop_map, self.dil_factor)
+
+         return out
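For reference, this is roughly how the recognition zoo (last hunk below) assembles a standalone predictor; the keyword arguments mirror the _predictor call there, so treat this as an illustrative sketch rather than the canonical entry point. The default model_path is a release URL, so expect a download on first use, and random input naturally yields meaningless text.

import numpy as np
from onnxtr.models import recognition
from onnxtr.models.preprocessor import PreProcessor
from onnxtr.models.recognition.predictor import RecognitionPredictor

model = recognition.crnn_vgg16_bn()
pre_processor = PreProcessor(
    model.cfg["input_shape"][1:],   # (H, W) expected by the model
    batch_size=256,
    mean=model.cfg["mean"],
    std=model.cfg["std"],
    preserve_aspect_ratio=True,
)
predictor = RecognitionPredictor(pre_processor, model)

crops = [(255 * np.random.rand(32, 100, 3)).astype(np.uint8)]
print(predictor(crops))   # [('...', 0.87)] -> (value, confidence) pairs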
onnxtr/models/recognition/utils.py
@@ -0,0 +1,89 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import List
+
+ from rapidfuzz.distance import Levenshtein
+
+ __all__ = ["merge_strings", "merge_multi_strings"]
+
+
+ def merge_strings(a: str, b: str, dil_factor: float) -> str:
+     """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
+
+     Args:
+     ----
+         a: first char seq, suffix should be similar to b's prefix.
+         b: second char seq, prefix should be similar to a's suffix.
+         dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
+             only used when the mother sequence is split on a character repetition
+
+     Returns:
+     -------
+         A merged character sequence.
+
+     Example::
+         >>> from onnxtr.models.recognition.utils import merge_strings
+         >>> merge_strings('abcd', 'cdefgh', 1.4)
+         'abcdefgh'
+         >>> merge_strings('abcdi', 'cdefgh', 1.4)
+         'abcdefgh'
+     """
+     seq_len = min(len(a), len(b))
+     if seq_len == 0:  # One sequence is empty, return the other
+         return b if len(a) == 0 else a
+
+     # Initialize merging index and corresponding score (mean Levenshtein)
+     min_score, index = 1.0, 0  # No overlap, just concatenate
+
+     scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
+
+     # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
+     if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
+         # Compute n_overlap (number of overlapping chars, geometrically determined)
+         n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
+         # Find the number of consecutive zeros in the scores list
+         # Impossible to have a zero after a non-zero score in that case
+         n_zeros = sum(val == 0 for val in scores)
+         # Index is bounded by the geometrical overlap to avoid collapsing repetitions
+         min_score, index = 0, min(n_zeros, n_overlap)
+
+     else:  # Common case: choose the min score index
+         for i, score in enumerate(scores):
+             if score < min_score:
+                 min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char
+
+     # Merge with correct overlap
+     if index == 0:
+         return a + b
+     return a[:-1] + b[index - 1 :]
+
+
+ def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
+     """Recursively merges consecutive string sequences with overlapping characters.
+
+     Args:
+     ----
+         seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
+         dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
+             only used when the mother sequence is split on a character repetition
+
+     Returns:
+     -------
+         A merged character sequence
+
+     Example::
+         >>> from onnxtr.models.recognition.utils import merge_multi_strings
+         >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
+         'abcdefghijkl'
+     """
+
+     def _recursive_merge(a: str, seq_list: List[str], dil_factor: float) -> str:
+         # Recursive version of compute_overlap
+         if len(seq_list) == 1:
+             return merge_strings(a, seq_list[0], dil_factor)
+         return _recursive_merge(merge_strings(a, seq_list[0], dil_factor), seq_list[1:], dil_factor)
+
+     return _recursive_merge("", seq_list, dil_factor)
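To make the overlap selection in merge_strings concrete, here is the score computation for the first docstring example, worked by hand with rapidfuzz's Levenshtein distance: the best overlap is the candidate length with the lowest normalized distance.

from rapidfuzz.distance import Levenshtein

a, b = "abcd", "cdefgh"
seq_len = min(len(a), len(b))  # 4
scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
# i=1: 'd'    vs 'c'    -> 1/1 = 1.0
# i=2: 'cd'   vs 'cd'   -> 0/2 = 0.0   <- best overlap: 2 characters, so index = 2
# i=3: 'bcd'  vs 'cde'  -> 2/3 ~ 0.67
# i=4: 'abcd' vs 'cdef' -> 4/4 = 1.0
# merge: a[:-1] + b[index - 1:] = 'abc' + 'defgh' = 'abcdefgh'
print(scores)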
onnxtr/models/recognition/zoo.py
@@ -0,0 +1,69 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, List
+
+ from onnxtr.models.preprocessor import PreProcessor
+
+ from .. import recognition
+ from .predictor import RecognitionPredictor
+
+ __all__ = ["recognition_predictor"]
+
+
+ ARCHS: List[str] = [
+     "crnn_vgg16_bn",
+     "crnn_mobilenet_v3_small",
+     "crnn_mobilenet_v3_large",
+     "sar_resnet31",
+     "master",
+     "vitstr_small",
+     "vitstr_base",
+     "parseq",
+ ]
+
+
+ def _predictor(arch: Any, **kwargs: Any) -> RecognitionPredictor:
+     if isinstance(arch, str):
+         if arch not in ARCHS:
+             raise ValueError(f"unknown architecture '{arch}'")
+
+         _model = recognition.__dict__[arch]()
+     else:
+         if not isinstance(
+             arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
+         ):
+             raise ValueError(f"unknown architecture: {type(arch)}")
+         _model = arch
+
+     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
+     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
+     kwargs["batch_size"] = kwargs.get("batch_size", 1024)
+     input_shape = _model.cfg["input_shape"][1:]
+     predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)
+
+     return predictor
+
+
+ def recognition_predictor(arch: Any = "crnn_vgg16_bn", **kwargs: Any) -> RecognitionPredictor:
+     """Text recognition architecture.
+
+     Example::
+         >>> import numpy as np
+         >>> from onnxtr.models import recognition_predictor
+         >>> model = recognition_predictor()
+         >>> input_page = (255 * np.random.rand(32, 128, 3)).astype(np.uint8)
+         >>> out = model([input_page])
+
+     Args:
+     ----
+         arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
+         **kwargs: optional parameters to be passed to the architecture
+
+     Returns:
+     -------
+         Recognition predictor
+     """
+     return _predictor(arch, **kwargs)
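Putting it together, the _predictor check above accepts either an architecture name or an already-built model instance, and extra keyword arguments flow into the PreProcessor. A short usage sketch; the second form and the batch_size override are inferred from the code above rather than from package documentation:

import numpy as np
from onnxtr.models import recognition, recognition_predictor

# 1) by architecture name, with a smaller batch size than the default 1024
predictor = recognition_predictor("vitstr_small", batch_size=64)

# 2) from a pre-built model instance
model = recognition.parseq()
predictor_from_model = recognition_predictor(model)

crop = (255 * np.random.rand(32, 128, 3)).astype(np.uint8)
print(predictor([crop]))   # [('...', 0.93)] -> (value, confidence)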