onnxtr 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. onnxtr/contrib/__init__.py +1 -0
  2. onnxtr/contrib/artefacts.py +6 -8
  3. onnxtr/contrib/base.py +7 -16
  4. onnxtr/file_utils.py +1 -3
  5. onnxtr/io/elements.py +54 -60
  6. onnxtr/io/html.py +0 -2
  7. onnxtr/io/image.py +1 -4
  8. onnxtr/io/pdf.py +3 -5
  9. onnxtr/io/reader.py +4 -10
  10. onnxtr/models/_utils.py +10 -17
  11. onnxtr/models/builder.py +17 -30
  12. onnxtr/models/classification/models/mobilenet.py +7 -12
  13. onnxtr/models/classification/predictor/base.py +6 -7
  14. onnxtr/models/classification/zoo.py +25 -11
  15. onnxtr/models/detection/_utils/base.py +3 -7
  16. onnxtr/models/detection/core.py +2 -8
  17. onnxtr/models/detection/models/differentiable_binarization.py +10 -17
  18. onnxtr/models/detection/models/fast.py +10 -17
  19. onnxtr/models/detection/models/linknet.py +10 -17
  20. onnxtr/models/detection/postprocessor/base.py +3 -9
  21. onnxtr/models/detection/predictor/base.py +4 -5
  22. onnxtr/models/detection/zoo.py +20 -6
  23. onnxtr/models/engine.py +9 -9
  24. onnxtr/models/factory/hub.py +3 -7
  25. onnxtr/models/predictor/base.py +29 -30
  26. onnxtr/models/predictor/predictor.py +4 -5
  27. onnxtr/models/preprocessor/base.py +8 -12
  28. onnxtr/models/recognition/core.py +0 -1
  29. onnxtr/models/recognition/models/crnn.py +11 -23
  30. onnxtr/models/recognition/models/master.py +9 -15
  31. onnxtr/models/recognition/models/parseq.py +8 -12
  32. onnxtr/models/recognition/models/sar.py +8 -12
  33. onnxtr/models/recognition/models/vitstr.py +9 -15
  34. onnxtr/models/recognition/predictor/_utils.py +6 -9
  35. onnxtr/models/recognition/predictor/base.py +3 -3
  36. onnxtr/models/recognition/utils.py +2 -7
  37. onnxtr/models/recognition/zoo.py +19 -7
  38. onnxtr/models/zoo.py +7 -9
  39. onnxtr/transforms/base.py +17 -6
  40. onnxtr/utils/common_types.py +7 -8
  41. onnxtr/utils/data.py +7 -11
  42. onnxtr/utils/fonts.py +1 -6
  43. onnxtr/utils/geometry.py +18 -49
  44. onnxtr/utils/multithreading.py +3 -5
  45. onnxtr/utils/reconstitution.py +139 -38
  46. onnxtr/utils/repr.py +1 -2
  47. onnxtr/utils/visualization.py +12 -21
  48. onnxtr/utils/vocabs.py +1 -2
  49. onnxtr/version.py +1 -1
  50. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/METADATA +71 -41
  51. onnxtr-0.6.0.dist-info/RECORD +75 -0
  52. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/WHEEL +1 -1
  53. onnxtr-0.5.0.dist-info/RECORD +0 -75
  54. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/LICENSE +0 -0
  55. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/top_level.txt +0 -0
  56. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/zip-safe +0 -0
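The bulk of the changes below is a mechanical typing modernization: `typing.List/Dict/Tuple` give way to the PEP 585 builtin generics `list/dict/tuple`, `Optional[X]` becomes a PEP 604 union (`X | None`), `Callable` now comes from `collections.abc`, and the numpydoc-style `----`/`-------` underlines are dropped from docstrings. A minimal sketch of the pattern, using a hypothetical helper that is not part of onnxtr:

from collections.abc import Callable  # 0.6.0: Callable moves out of typing
from typing import Any

# 0.5.0 spelled the same signature as:
#   from typing import Any, Callable, Dict, List, Optional, Tuple
#   def apply_hooks(hooks: List[Callable], cfg: Optional[Dict[str, Any]] = None) -> Tuple[int, ...]: ...
def apply_hooks(hooks: list[Callable], cfg: dict[str, Any] | None = None) -> tuple[int, ...]:
    # hypothetical helper, shown only to illustrate the annotation change
    return (len(hooks),)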
onnxtr/models/predictor/base.py

@@ -3,7 +3,8 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 
@@ -24,7 +25,6 @@ class _OCRPredictor:
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
-    ----
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
             without rotated textual elements.
         straighten_pages: if True, estimates the page general orientation based on the median line orientation.
@@ -39,8 +39,8 @@ class _OCRPredictor:
         **kwargs: keyword args of `DocumentBuilder`
     """
 
-    crop_orientation_predictor: Optional[OrientationPredictor]
-    page_orientation_predictor: Optional[OrientationPredictor]
+    crop_orientation_predictor: OrientationPredictor | None
+    page_orientation_predictor: OrientationPredictor | None
 
     def __init__(
         self,
@@ -50,7 +50,7 @@ class _OCRPredictor:
         symmetric_pad: bool = True,
         detect_orientation: bool = False,
         load_in_8_bit: bool = False,
-        clf_engine_cfg: Optional[EngineConfig] = None,
+        clf_engine_cfg: EngineConfig | None = None,
         **kwargs: Any,
     ) -> None:
         self.assume_straight_pages = assume_straight_pages
@@ -74,12 +74,12 @@ class _OCRPredictor:
         self.doc_builder = DocumentBuilder(**kwargs)
         self.preserve_aspect_ratio = preserve_aspect_ratio
         self.symmetric_pad = symmetric_pad
-        self.hooks: List[Callable] = []
+        self.hooks: list[Callable] = []
 
     def _general_page_orientations(
         self,
-        pages: List[np.ndarray],
-    ) -> List[Tuple[int, float]]:
+        pages: list[np.ndarray],
+    ) -> list[tuple[int, float]]:
         _, classes, probs = zip(self.page_orientation_predictor(pages))  # type: ignore[misc]
         # Flatten to list of tuples with (value, confidence)
         page_orientations = [
@@ -90,8 +90,8 @@ class _OCRPredictor:
         return page_orientations
 
     def _get_orientations(
-        self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
-    ) -> Tuple[List[Tuple[int, float]], List[int]]:
+        self, pages: list[np.ndarray], seg_maps: list[np.ndarray]
+    ) -> tuple[list[tuple[int, float]], list[int]]:
         general_pages_orientations = self._general_page_orientations(pages)
         origin_page_orientations = [
             estimate_orientation(seq_map, general_orientation)
@@ -101,11 +101,11 @@ class _OCRPredictor:
 
     def _straighten_pages(
         self,
-        pages: List[np.ndarray],
-        seg_maps: List[np.ndarray],
-        general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
-        origin_pages_orientations: Optional[List[int]] = None,
-    ) -> List[np.ndarray]:
+        pages: list[np.ndarray],
+        seg_maps: list[np.ndarray],
+        general_pages_orientations: list[tuple[int, float]] | None = None,
+        origin_pages_orientations: list[int] | None = None,
+    ) -> list[np.ndarray]:
         general_pages_orientations = (
             general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
         )
@@ -125,12 +125,12 @@ class _OCRPredictor:
 
     @staticmethod
     def _generate_crops(
-        pages: List[np.ndarray],
-        loc_preds: List[np.ndarray],
+        pages: list[np.ndarray],
+        loc_preds: list[np.ndarray],
         channels_last: bool,
         assume_straight_pages: bool = False,
         assume_horizontal: bool = False,
-    ) -> List[List[np.ndarray]]:
+    ) -> list[list[np.ndarray]]:
         if assume_straight_pages:
             crops = [
                 extract_crops(page, _boxes[:, :4], channels_last=channels_last)
@@ -145,12 +145,12 @@ class _OCRPredictor:
 
     @staticmethod
     def _prepare_crops(
-        pages: List[np.ndarray],
-        loc_preds: List[np.ndarray],
+        pages: list[np.ndarray],
+        loc_preds: list[np.ndarray],
         channels_last: bool,
         assume_straight_pages: bool = False,
         assume_horizontal: bool = False,
-    ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
+    ) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
         crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
 
         # Avoid sending zero-sized crops
@@ -165,9 +165,9 @@ class _OCRPredictor:
 
     def _rectify_crops(
         self,
-        crops: List[List[np.ndarray]],
-        loc_preds: List[np.ndarray],
-    ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]:
+        crops: list[list[np.ndarray]],
+        loc_preds: list[np.ndarray],
+    ) -> tuple[list[list[np.ndarray]], list[np.ndarray], list[tuple[int, float]]]:
         # Work at a page level
         orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops])  # type: ignore[misc]
         rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
@@ -185,10 +185,10 @@ class _OCRPredictor:
 
     @staticmethod
     def _process_predictions(
-        loc_preds: List[np.ndarray],
-        word_preds: List[Tuple[str, float]],
-        crop_orientations: List[Dict[str, Any]],
-    ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]:
+        loc_preds: list[np.ndarray],
+        word_preds: list[tuple[str, float]],
+        crop_orientations: list[dict[str, Any]],
+    ) -> tuple[list[np.ndarray], list[list[tuple[str, float]]], list[list[dict[str, Any]]]]:
         text_preds = []
         crop_orientation_preds = []
         if len(loc_preds) > 0:
@@ -205,10 +205,9 @@ class _OCRPredictor:
         """Add a hook to the predictor
 
         Args:
-        ----
            hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
         """
         self.hooks.append(hook)
 
-    def list_archs(self) -> Dict[str, List[str]]:
+    def list_archs(self) -> dict[str, list[str]]:
         return {"detection_archs": DETECTION_ARCHS, "recognition_archs": RECOGNITION_ARCHS}
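The `list_archs` helper at the end of this file makes the available architectures discoverable at runtime. A usage sketch, assuming the `ocr_predictor` zoo entry point from the project README:

from onnxtr.models import ocr_predictor

predictor = ocr_predictor()        # default detection + recognition models
archs = predictor.list_archs()     # dict[str, list[str]], per the hunk above
print(archs["detection_archs"])
print(archs["recognition_archs"])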
onnxtr/models/predictor/predictor.py

@@ -3,7 +3,7 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Optional
+from typing import Any
 
 import numpy as np
 
@@ -24,7 +24,6 @@ class OCRPredictor(NestedObject, _OCRPredictor):
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
-    ----
         det_predictor: detection module
         reco_predictor: recognition module
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -52,7 +51,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         symmetric_pad: bool = True,
         detect_orientation: bool = False,
         detect_language: bool = False,
-        clf_engine_cfg: Optional[EngineConfig] = None,
+        clf_engine_cfg: EngineConfig | None = None,
         **kwargs: Any,
     ) -> None:
         self.det_predictor = det_predictor
@@ -72,7 +71,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
 
     def __call__(
         self,
-        pages: List[np.ndarray],
+        pages: list[np.ndarray],
         **kwargs: Any,
     ) -> Document:
         # Dimension check
@@ -147,7 +146,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
             boxes,
             objectness_scores,
             text_preds,
-            origin_page_shapes,  # type: ignore[arg-type]
+            origin_page_shapes,
             crop_orientations,
             orientations,
             languages_dict,
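`OCRPredictor.__call__` now takes a plain `list[np.ndarray]` of pages and returns a `Document`. A minimal end-to-end sketch, where the zero-filled page is only a stand-in (real pages would typically come from `onnxtr.io.DocumentFile`) and `Document.render()` is assumed to behave as in docTR:

import numpy as np
from onnxtr.models import ocr_predictor

predictor = ocr_predictor()
pages: list[np.ndarray] = [np.zeros((1024, 768, 3), dtype=np.uint8)]  # one blank page
result = predictor(pages)   # OCRPredictor.__call__ -> Document
print(result.render())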
onnxtr/models/preprocessor/base.py

@@ -4,7 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
-from typing import Any, List, Tuple, Union
+from typing import Any
 
 import numpy as np
 
@@ -20,36 +20,34 @@ class PreProcessor(NestedObject):
     """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
 
     Args:
-    ----
         output_size: expected size of each page in format (H, W)
         batch_size: the size of page batches
         mean: mean value of the training distribution by channel
         std: standard deviation of the training distribution by channel
+        **kwargs: additional arguments for the resizing operation
     """
 
-    _children_names: List[str] = ["resize", "normalize"]
+    _children_names: list[str] = ["resize", "normalize"]
 
     def __init__(
         self,
-        output_size: Tuple[int, int],
+        output_size: tuple[int, int],
         batch_size: int,
-        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
-        std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+        mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
+        std: tuple[float, float, float] = (1.0, 1.0, 1.0),
         **kwargs: Any,
     ) -> None:
         self.batch_size = batch_size
         self.resize = Resize(output_size, **kwargs)
         self.normalize = Normalize(mean, std)
 
-    def batch_inputs(self, samples: List[np.ndarray]) -> List[np.ndarray]:
+    def batch_inputs(self, samples: list[np.ndarray]) -> list[np.ndarray]:
         """Gather samples into batches for inference purposes
 
         Args:
-        ----
             samples: list of samples (tf.Tensor)
 
         Returns:
-        -------
             list of batched samples
         """
         num_batches = int(math.ceil(len(samples) / self.batch_size))
@@ -76,15 +74,13 @@ class PreProcessor(NestedObject):
 
         return x
 
-    def __call__(self, x: Union[np.ndarray, List[np.ndarray]]) -> List[np.ndarray]:
+    def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[np.ndarray]:
         """Prepare document data for model forwarding
 
         Args:
-        ----
             x: list of images (np.array) or tensors (already resized and batched)
 
         Returns:
-        -------
             list of page batches
         """
         # Input type check
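`batch_inputs` splits the sample list into `ceil(len(samples) / batch_size)` chunks, so the last batch simply keeps the remainder. A standalone sketch of that rule (not the shipped method body, which also handles casting and resizing):

import math

import numpy as np

def batch_inputs(samples: list[np.ndarray], batch_size: int) -> list[np.ndarray]:
    num_batches = int(math.ceil(len(samples) / batch_size))
    return [np.stack(samples[idx * batch_size : (idx + 1) * batch_size]) for idx in range(num_batches)]

batches = batch_inputs([np.zeros((32, 128, 3))] * 5, batch_size=2)
print([b.shape[0] for b in batches])  # [2, 2, 1] -> last batch keeps the remainder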
onnxtr/models/recognition/core.py

@@ -13,7 +13,6 @@ class RecognitionPostProcessor(NestedObject):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
onnxtr/models/recognition/models/crnn.py

@@ -5,7 +5,7 @@
 
 from copy import deepcopy
 from itertools import groupby
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 import numpy as np
 from scipy.special import softmax
@@ -17,7 +17,7 @@ from ..core import RecognitionPostProcessor
 
 __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "crnn_vgg16_bn": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -49,7 +49,6 @@ class CRNNPostProcessor(RecognitionPostProcessor):
     """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
@@ -69,13 +68,11 @@ class CRNNPostProcessor(RecognitionPostProcessor):
         <https://github.com/githubharald/CTCDecoder>`_.
 
         Args:
-        ----
             logits: model output, shape: N x T x C
             vocab: vocabulary to use
             blank: index of blank label
 
         Returns:
-        -------
             A list of tuples: (word, confidence)
         """
         # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
@@ -94,11 +91,9 @@ class CRNNPostProcessor(RecognitionPostProcessor):
         with label_to_idx mapping dictionnary
 
         Args:
-        ----
             logits: raw output of the model, shape (N, C + 1, seq_len)
 
         Returns:
-        -------
             A tuple of 2 lists: a list of str (words) and a list of float (probs)
 
         """
@@ -110,7 +105,6 @@ class CRNN(Engine):
     """CRNN Onnx loader
 
     Args:
-    ----
         model_path: path or url to onnx model file
         vocab: vocabulary used for encoding
         engine_cfg: configuration for the inference engine
@@ -118,14 +112,14 @@ class CRNN(Engine):
         **kwargs: additional arguments to be passed to `Engine`
     """
 
-    _children_names: List[str] = ["postprocessor"]
+    _children_names: list[str] = ["postprocessor"]
 
     def __init__(
         self,
         model_path: str,
         vocab: str,
-        engine_cfg: Optional[EngineConfig] = None,
-        cfg: Optional[Dict[str, Any]] = None,
+        engine_cfg: EngineConfig | None = None,
+        cfg: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -139,10 +133,10 @@ class CRNN(Engine):
         self,
         x: np.ndarray,
         return_model_output: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         logits = self.run(x)
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if return_model_output:
             out["out_map"] = logits
 
@@ -156,7 +150,7 @@ def _crnn(
     arch: str,
     model_path: str,
     load_in_8_bit: bool = False,
-    engine_cfg: Optional[EngineConfig] = None,
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> CRNN:
     kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
@@ -174,7 +168,7 @@ def _crnn(
 def crnn_vgg16_bn(
     model_path: str = default_cfgs["crnn_vgg16_bn"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg: Optional[EngineConfig] = None,
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> CRNN:
     """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
@@ -187,14 +181,12 @@ def crnn_vgg16_bn(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the CRNN architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -203,7 +195,7 @@ def crnn_vgg16_bn(
 def crnn_mobilenet_v3_small(
     model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg: Optional[EngineConfig] = None,
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> CRNN:
     """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
@@ -216,14 +208,12 @@ def crnn_mobilenet_v3_small(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the CRNN architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -232,7 +222,7 @@ def crnn_mobilenet_v3_small(
 def crnn_mobilenet_v3_large(
     model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg: Optional[EngineConfig] = None,
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> CRNN:
     """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
@@ -245,14 +235,12 @@ def crnn_mobilenet_v3_large(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the CRNN architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
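For reference, best-path CTC decoding of the kind `ctc_best_path` performs: softmax the logits, take the per-step argmax, merge repeated symbols, drop blanks, and use the smallest per-step max probability as the sequence confidence (as the comment in the hunk above describes). A self-contained sketch, with index conventions assumed (blank as the last class); the shipped implementation may differ in detail:

from itertools import groupby

import numpy as np
from scipy.special import softmax

def ctc_best_path(logits: np.ndarray, vocab: str, blank: int) -> list[tuple[str, float]]:
    probs = softmax(logits, axis=-1)            # N x T x C
    confs = probs.max(axis=-1).min(axis=1)      # smallest per-step max prob per sequence
    best = np.argmax(probs, axis=-1)            # N x T
    # merge repeats, drop blanks, map remaining indices to characters
    words = ["".join(vocab[k] for k, _ in groupby(seq) if k != blank) for seq in best]
    return list(zip(words, confs.tolist()))

out = ctc_best_path(np.random.rand(2, 32, len("abc") + 1), "abc", blank=len("abc"))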
onnxtr/models/recognition/models/master.py

@@ -4,7 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import numpy as np
 from scipy.special import softmax
@@ -17,7 +17,7 @@ from ..core import RecognitionPostProcessor
 __all__ = ["MASTER", "master"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "master": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -33,7 +33,6 @@ class MASTER(Engine):
     """MASTER Onnx loader
 
     Args:
-    ----
         model_path: path or url to onnx model file
         vocab: vocabulary, (without EOS, SOS, PAD)
         engine_cfg: configuration for the inference engine
@@ -45,8 +44,8 @@ class MASTER(Engine):
         self,
         model_path: str,
         vocab: str,
-        engine_cfg: Optional[EngineConfig] = None,
-        cfg: Optional[Dict[str, Any]] = None,
+        engine_cfg: EngineConfig | None = None,
+        cfg: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -60,20 +59,18 @@ class MASTER(Engine):
         self,
         x: np.ndarray,
         return_model_output: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Call function
 
         Args:
-        ----
             x: images
             return_model_output: if True, return logits
 
         Returns:
-        -------
             A dictionnary containing eventually logits and predictions.
         """
         logits = self.run(x)
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
 
         if return_model_output:
             out["out_map"] = logits
@@ -87,7 +84,6 @@ class MASTERPostProcessor(RecognitionPostProcessor):
     """Post-processor for the MASTER model
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
@@ -98,7 +94,7 @@ class MASTERPostProcessor(RecognitionPostProcessor):
         super().__init__(vocab)
         self._embedding = list(vocab) + ["<eos>"] + ["<sos>"] + ["<pad>"]
 
-    def __call__(self, logits: np.ndarray) -> List[Tuple[str, float]]:
+    def __call__(self, logits: np.ndarray) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = np.argmax(logits, axis=-1)
         # N x L
@@ -117,7 +113,7 @@ def _master(
     arch: str,
     model_path: str,
     load_in_8_bit: bool = False,
-    engine_cfg: Optional[EngineConfig] = None,
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> MASTER:
     # Patch the config
@@ -135,7 +131,7 @@ def _master(
 def master(
     model_path: str = default_cfgs["master"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg: Optional[EngineConfig] = None,
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> MASTER:
     """MASTER as described in paper: <https://arxiv.org/pdf/1910.02562.pdf>`_.
@@ -147,14 +143,12 @@ def master(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keywoard arguments passed to the MASTER architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _master("master", model_path, load_in_8_bit, engine_cfg, **kwargs)
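`MASTERPostProcessor.__call__` decodes attention-model output by taking the per-step argmax, mapping indices through the embedding (`vocab` plus `<eos>`, `<sos>`, `<pad>`), and cutting each word at the first `<eos>`. A sketch of that decode path, with confidence handling omitted; this is an illustration, not the shipped body:

import numpy as np

def decode_master(logits: np.ndarray, vocab: str) -> list[str]:
    embedding = list(vocab) + ["<eos>"] + ["<sos>"] + ["<pad>"]
    out_idxs = np.argmax(logits, axis=-1)      # N x L
    words = []
    for seq in out_idxs:
        raw = "".join(embedding[int(idx)] for idx in seq)
        words.append(raw.split("<eos>")[0])    # drop everything after the first <eos>
    return words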
onnxtr/models/recognition/models/parseq.py

@@ -4,7 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, Optional
+from typing import Any
 
 import numpy as np
 from scipy.special import softmax
@@ -16,7 +16,7 @@ from ..core import RecognitionPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "parseq": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -32,7 +32,6 @@ class PARSeq(Engine):
     """PARSeq Onnx loader
 
     Args:
-    ----
         model_path: path to onnx model file
         vocab: vocabulary used for encoding
         engine_cfg: configuration for the inference engine
@@ -44,8 +43,8 @@ class PARSeq(Engine):
         self,
         model_path: str,
         vocab: str,
-        engine_cfg: Optional[EngineConfig] = None,
-        cfg: Optional[Dict[str, Any]] = None,
+        engine_cfg: EngineConfig | None = None,
+        cfg: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -59,9 +58,9 @@ class PARSeq(Engine):
         self,
         x: np.ndarray,
         return_model_output: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         logits = self.run(x)
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
 
         if return_model_output:
             out["out_map"] = logits
@@ -74,7 +73,6 @@ class PARSeqPostProcessor(RecognitionPostProcessor):
     """Post processor for PARSeq architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
@@ -106,7 +104,7 @@ def _parseq(
     arch: str,
     model_path: str,
     load_in_8_bit: bool = False,
-    engine_cfg: Optional[EngineConfig] = None,
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> PARSeq:
     # Patch the config
@@ -125,7 +123,7 @@ def _parseq(
 def parseq(
     model_path: str = default_cfgs["parseq"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg: Optional[EngineConfig] = None,
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> PARSeq:
     """PARSeq architecture from
@@ -138,14 +136,12 @@ def parseq(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the PARSeq architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _parseq("parseq", model_path, load_in_8_bit, engine_cfg, **kwargs)
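Across all the recognition builders above, `engine_cfg` changes from `Optional[EngineConfig]` to `EngineConfig | None` with unchanged behavior: pass `None` for defaults, or an `EngineConfig` to control the ONNX Runtime session. A sketch, assuming `EngineConfig` is exported from `onnxtr.models` and accepts a `providers` argument as in the project README:

from onnxtr.models import EngineConfig, parseq

engine_cfg = EngineConfig(providers=["CPUExecutionProvider"])  # assumed signature
model = parseq(load_in_8_bit=True, engine_cfg=engine_cfg)      # 8-bit quantized weights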