onnxtr 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. onnxtr/io/elements.py +17 -4
  2. onnxtr/io/pdf.py +6 -3
  3. onnxtr/models/__init__.py +1 -0
  4. onnxtr/models/_utils.py +57 -20
  5. onnxtr/models/builder.py +24 -9
  6. onnxtr/models/classification/models/mobilenet.py +12 -5
  7. onnxtr/models/classification/zoo.py +18 -6
  8. onnxtr/models/detection/_utils/__init__.py +1 -0
  9. onnxtr/models/detection/_utils/base.py +66 -0
  10. onnxtr/models/detection/models/differentiable_binarization.py +27 -12
  11. onnxtr/models/detection/models/fast.py +30 -9
  12. onnxtr/models/detection/models/linknet.py +24 -9
  13. onnxtr/models/detection/postprocessor/base.py +4 -3
  14. onnxtr/models/detection/predictor/base.py +15 -1
  15. onnxtr/models/detection/zoo.py +12 -3
  16. onnxtr/models/engine.py +73 -7
  17. onnxtr/models/predictor/base.py +65 -42
  18. onnxtr/models/predictor/predictor.py +22 -15
  19. onnxtr/models/recognition/models/crnn.py +24 -9
  20. onnxtr/models/recognition/models/master.py +14 -5
  21. onnxtr/models/recognition/models/parseq.py +14 -5
  22. onnxtr/models/recognition/models/sar.py +12 -5
  23. onnxtr/models/recognition/models/vitstr.py +18 -7
  24. onnxtr/models/recognition/zoo.py +9 -6
  25. onnxtr/models/zoo.py +16 -0
  26. onnxtr/py.typed +0 -0
  27. onnxtr/utils/geometry.py +33 -12
  28. onnxtr/version.py +1 -1
  29. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/METADATA +60 -21
  30. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/RECORD +34 -31
  31. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/WHEEL +1 -1
  32. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/top_level.txt +0 -1
  33. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/LICENSE +0 -0
  34. {onnxtr-0.2.0.dist-info → onnxtr-0.3.0.dist-info}/zip-safe +0 -0
@@ -8,10 +8,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
8
8
  import numpy as np
9
9
 
10
10
  from onnxtr.models.builder import DocumentBuilder
11
- from onnxtr.utils.geometry import extract_crops, extract_rcrops
11
+ from onnxtr.models.engine import EngineConfig
12
+ from onnxtr.utils.geometry import extract_crops, extract_rcrops, rotate_image
12
13
 
13
- from .._utils import rectify_crops, rectify_loc_preds
14
- from ..classification import crop_orientation_predictor
14
+ from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
15
+ from ..classification import crop_orientation_predictor, page_orientation_predictor
15
16
  from ..classification.predictor import OrientationPredictor
16
17
  from ..detection.zoo import ARCHS as DETECTION_ARCHS
17
18
  from ..recognition.zoo import ARCHS as RECOGNITION_ARCHS
@@ -31,11 +32,15 @@ class _OCRPredictor:
31
32
  accordingly. Doing so will improve performances for documents with page-uniform rotations.
32
33
  preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
33
34
  symmetric_pad: if True and preserve_aspect_ratio is True, pads the image symmetrically.
35
+ detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
36
+ page. Doing so will slightly deteriorate the overall latency.
34
37
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
38
+ clf_engine_cfg: configuration of the orientation classification engine
35
39
  **kwargs: keyword args of `DocumentBuilder`
36
40
  """
37
41
 
38
42
  crop_orientation_predictor: Optional[OrientationPredictor]
43
+ page_orientation_predictor: Optional[OrientationPredictor]
39
44
 
40
45
  def __init__(
41
46
  self,
@@ -43,19 +48,75 @@ class _OCRPredictor:
43
48
  straighten_pages: bool = False,
44
49
  preserve_aspect_ratio: bool = True,
45
50
  symmetric_pad: bool = True,
51
+ detect_orientation: bool = False,
46
52
  load_in_8_bit: bool = False,
53
+ clf_engine_cfg: EngineConfig = EngineConfig(),
47
54
  **kwargs: Any,
48
55
  ) -> None:
49
56
  self.assume_straight_pages = assume_straight_pages
50
57
  self.straighten_pages = straighten_pages
51
58
  self.crop_orientation_predictor = (
52
- None if assume_straight_pages else crop_orientation_predictor(load_in_8_bit=load_in_8_bit)
59
+ None
60
+ if assume_straight_pages
61
+ else crop_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
62
+ )
63
+ self.page_orientation_predictor = (
64
+ page_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
65
+ if detect_orientation or straighten_pages or not assume_straight_pages
66
+ else None
53
67
  )
54
68
  self.doc_builder = DocumentBuilder(**kwargs)
55
69
  self.preserve_aspect_ratio = preserve_aspect_ratio
56
70
  self.symmetric_pad = symmetric_pad
57
71
  self.hooks: List[Callable] = []
58
72
 
73
+ def _general_page_orientations(
74
+ self,
75
+ pages: List[np.ndarray],
76
+ ) -> List[Tuple[int, float]]:
77
+ _, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
78
+ # Flatten to list of tuples with (value, confidence)
79
+ page_orientations = [
80
+ (orientation, prob)
81
+ for page_classes, page_probs in zip(classes, probs)
82
+ for orientation, prob in zip(page_classes, page_probs)
83
+ ]
84
+ return page_orientations
85
+
86
+ def _get_orientations(
87
+ self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
88
+ ) -> Tuple[List[Tuple[int, float]], List[int]]:
89
+ general_pages_orientations = self._general_page_orientations(pages)
90
+ origin_page_orientations = [
91
+ estimate_orientation(seq_map, general_orientation)
92
+ for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
93
+ ]
94
+ return general_pages_orientations, origin_page_orientations
95
+
96
+ def _straighten_pages(
97
+ self,
98
+ pages: List[np.ndarray],
99
+ seg_maps: List[np.ndarray],
100
+ general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
101
+ origin_pages_orientations: Optional[List[int]] = None,
102
+ ) -> List[np.ndarray]:
103
+ general_pages_orientations = (
104
+ general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
105
+ )
106
+ origin_pages_orientations = (
107
+ origin_pages_orientations
108
+ if origin_pages_orientations
109
+ else [
110
+ estimate_orientation(seq_map, general_orientation)
111
+ for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
112
+ ]
113
+ )
114
+ return [
115
+ # We expand if the page is wider than tall and the angle is 90 or -90
116
+ rotate_image(page, angle, expand=page.shape[1] > page.shape[0] and abs(angle) == 90)
117
+ for page, angle in zip(pages, origin_pages_orientations)
118
+ ]
119
+
59
120
  @staticmethod
60
121
  def _generate_crops(
61
122
  pages: List[np.ndarray],
@@ -110,44 +171,6 @@ class _OCRPredictor:
110
171
  ]
111
172
  return rect_crops, rect_loc_preds, crop_orientations # type: ignore[return-value]
112
173
 
113
- def _remove_padding(
114
- self,
115
- pages: List[np.ndarray],
116
- loc_preds: List[np.ndarray],
117
- ) -> List[np.ndarray]:
118
- if self.preserve_aspect_ratio:
119
- # Rectify loc_preds to remove padding
120
- rectified_preds = []
121
- for page, loc_pred in zip(pages, loc_preds):
122
- h, w = page.shape[0], page.shape[1]
123
- if h > w:
124
- # y unchanged, dilate x coord
125
- if self.symmetric_pad:
126
- if self.assume_straight_pages:
127
- loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
128
- else:
129
- loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
130
- else:
131
- if self.assume_straight_pages:
132
- loc_pred[:, [0, 2]] *= h / w
133
- else:
134
- loc_pred[:, :, 0] *= h / w
135
- elif w > h:
136
- # x unchanged, dilate y coord
137
- if self.symmetric_pad:
138
- if self.assume_straight_pages:
139
- loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
140
- else:
141
- loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
142
- else:
143
- if self.assume_straight_pages:
144
- loc_pred[:, [1, 3]] *= w / h
145
- else:
146
- loc_pred[:, :, 1] *= w / h
147
- rectified_preds.append(loc_pred)
148
- return rectified_preds
149
- return loc_preds
150
-
151
174
  @staticmethod
152
175
  def _process_predictions(
153
176
  loc_preds: List[np.ndarray],
@@ -8,10 +8,11 @@ from typing import Any, List
8
8
  import numpy as np
9
9
 
10
10
  from onnxtr.io.elements import Document
11
- from onnxtr.models._utils import estimate_orientation, get_language
11
+ from onnxtr.models._utils import get_language
12
12
  from onnxtr.models.detection.predictor import DetectionPredictor
13
+ from onnxtr.models.engine import EngineConfig
13
14
  from onnxtr.models.recognition.predictor import RecognitionPredictor
14
- from onnxtr.utils.geometry import rotate_image
15
+ from onnxtr.utils.geometry import detach_scores
15
16
  from onnxtr.utils.repr import NestedObject
16
17
 
17
18
  from .base import _OCRPredictor
@@ -35,6 +36,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
35
36
  page. Doing so will slightly deteriorate the overall latency.
36
37
  detect_language: if True, the language prediction will be added to the predictions for each
37
38
  page. Doing so will slightly deteriorate the overall latency.
39
+ clf_engine_cfg: configuration of the orientation classification engine
38
40
  **kwargs: keyword args of `DocumentBuilder`
39
41
  """
40
42
 
@@ -50,12 +52,20 @@ class OCRPredictor(NestedObject, _OCRPredictor):
50
52
  symmetric_pad: bool = True,
51
53
  detect_orientation: bool = False,
52
54
  detect_language: bool = False,
55
+ clf_engine_cfg: EngineConfig = EngineConfig(),
53
56
  **kwargs: Any,
54
57
  ) -> None:
55
58
  self.det_predictor = det_predictor
56
59
  self.reco_predictor = reco_predictor
57
60
  _OCRPredictor.__init__(
58
- self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
61
+ self,
62
+ assume_straight_pages,
63
+ straighten_pages,
64
+ preserve_aspect_ratio,
65
+ symmetric_pad,
66
+ detect_orientation,
67
+ clf_engine_cfg=clf_engine_cfg,
68
+ **kwargs,
59
69
  )
60
70
  self.detect_orientation = detect_orientation
61
71
  self.detect_language = detect_language
@@ -80,26 +90,22 @@ class OCRPredictor(NestedObject, _OCRPredictor):
80
90
  for out_map in out_maps
81
91
  ]
82
92
  if self.detect_orientation:
83
- origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
93
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
84
94
  orientations = [
85
- {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
95
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
86
96
  ]
87
97
  else:
88
98
  orientations = None
99
+ general_pages_orientations = None
100
+ origin_pages_orientations = None
89
101
  if self.straighten_pages:
90
- origin_page_orientations = (
91
- origin_page_orientations
92
- if self.detect_orientation
93
- else [estimate_orientation(seq_map) for seq_map in seg_maps]
94
- )
95
- pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
102
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
103
+
96
104
  # forward again to get predictions on straight pages
97
105
  loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
98
106
 
99
- loc_preds = [loc_pred[0] for loc_pred in loc_preds]
100
-
101
- # Rectify crops if aspect ratio
102
- loc_preds = self._remove_padding(pages, loc_preds)
107
+ # Detach objectness scores from loc_preds
108
+ loc_preds, objectness_scores = detach_scores(loc_preds) # type: ignore[arg-type]
103
109
 
104
110
  # Apply hooks to loc_preds if any
105
111
  for hook in self.hooks:
@@ -136,6 +142,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
136
142
  out = self.doc_builder(
137
143
  pages,
138
144
  boxes,
145
+ objectness_scores,
139
146
  text_preds,
140
147
  origin_page_shapes, # type: ignore[arg-type]
141
148
  crop_orientations,
@@ -12,7 +12,7 @@ from scipy.special import softmax
12
12
 
13
13
  from onnxtr.utils import VOCABS
14
14
 
15
- from ...engine import Engine
15
+ from ...engine import Engine, EngineConfig
16
16
  from ..core import RecognitionPostProcessor
17
17
 
18
18
  __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -113,6 +113,7 @@ class CRNN(Engine):
113
113
  ----
114
114
  model_path: path or url to onnx model file
115
115
  vocab: vocabulary used for encoding
116
+ engine_cfg: configuration for the inference engine
116
117
  cfg: configuration dictionary
117
118
  **kwargs: additional arguments to be passed to `Engine`
118
119
  """
@@ -123,10 +124,11 @@ class CRNN(Engine):
123
124
  self,
124
125
  model_path: str,
125
126
  vocab: str,
127
+ engine_cfg: EngineConfig = EngineConfig(),
126
128
  cfg: Optional[Dict[str, Any]] = None,
127
129
  **kwargs: Any,
128
130
  ) -> None:
129
- super().__init__(url=model_path, **kwargs)
131
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
130
132
  self.vocab = vocab
131
133
  self.cfg = cfg
132
134
  self.postprocessor = CRNNPostProcessor(self.vocab)
@@ -152,6 +154,7 @@ def _crnn(
152
154
  arch: str,
153
155
  model_path: str,
154
156
  load_in_8_bit: bool = False,
157
+ engine_cfg: EngineConfig = EngineConfig(),
155
158
  **kwargs: Any,
156
159
  ) -> CRNN:
157
160
  kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
@@ -163,11 +166,14 @@ def _crnn(
163
166
  model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
164
167
 
165
168
  # Build the model
166
- return CRNN(model_path, cfg=_cfg, **kwargs)
169
+ return CRNN(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
167
170
 
168
171
 
169
172
  def crnn_vgg16_bn(
170
- model_path: str = default_cfgs["crnn_vgg16_bn"]["url"], load_in_8_bit: bool = False, **kwargs: Any
173
+ model_path: str = default_cfgs["crnn_vgg16_bn"]["url"],
174
+ load_in_8_bit: bool = False,
175
+ engine_cfg: EngineConfig = EngineConfig(),
176
+ **kwargs: Any,
171
177
  ) -> CRNN:
172
178
  """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
173
179
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
@@ -182,17 +188,21 @@ def crnn_vgg16_bn(
182
188
  ----
183
189
  model_path: path to onnx model file, defaults to url in default_cfgs
184
190
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
191
+ engine_cfg: configuration for the inference engine
185
192
  **kwargs: keyword arguments of the CRNN architecture
186
193
 
187
194
  Returns:
188
195
  -------
189
196
  text recognition architecture
190
197
  """
191
- return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, **kwargs)
198
+ return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, engine_cfg, **kwargs)
192
199
 
193
200
 
194
201
  def crnn_mobilenet_v3_small(
195
- model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any
202
+ model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"],
203
+ load_in_8_bit: bool = False,
204
+ engine_cfg: EngineConfig = EngineConfig(),
205
+ **kwargs: Any,
196
206
  ) -> CRNN:
197
207
  """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
198
208
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
@@ -207,17 +217,21 @@ def crnn_mobilenet_v3_small(
207
217
  ----
208
218
  model_path: path to onnx model file, defaults to url in default_cfgs
209
219
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
220
+ engine_cfg: configuration for the inference engine
210
221
  **kwargs: keyword arguments of the CRNN architecture
211
222
 
212
223
  Returns:
213
224
  -------
214
225
  text recognition architecture
215
226
  """
216
- return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, **kwargs)
227
+ return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
217
228
 
218
229
 
219
230
  def crnn_mobilenet_v3_large(
220
- model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, **kwargs: Any
231
+ model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"],
232
+ load_in_8_bit: bool = False,
233
+ engine_cfg: EngineConfig = EngineConfig(),
234
+ **kwargs: Any,
221
235
  ) -> CRNN:
222
236
  """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
223
237
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
@@ -232,10 +246,11 @@ def crnn_mobilenet_v3_large(
232
246
  ----
233
247
  model_path: path to onnx model file, defaults to url in default_cfgs
234
248
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
249
+ engine_cfg: configuration for the inference engine
235
250
  **kwargs: keyword arguments of the CRNN architecture
236
251
 
237
252
  Returns:
238
253
  -------
239
254
  text recognition architecture
240
255
  """
241
- return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs)
256
+ return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -11,7 +11,7 @@ from scipy.special import softmax
11
11
 
12
12
  from onnxtr.utils import VOCABS
13
13
 
14
- from ...engine import Engine
14
+ from ...engine import Engine, EngineConfig
15
15
  from ..core import RecognitionPostProcessor
16
16
 
17
17
  __all__ = ["MASTER", "master"]
@@ -36,6 +36,7 @@ class MASTER(Engine):
36
36
  ----
37
37
  model_path: path or url to onnx model file
38
38
  vocab: vocabulary, (without EOS, SOS, PAD)
39
+ engine_cfg: configuration for the inference engine
39
40
  cfg: dictionary containing information about the model
40
41
  **kwargs: additional arguments to be passed to `Engine`
41
42
  """
@@ -44,10 +45,11 @@ class MASTER(Engine):
44
45
  self,
45
46
  model_path: str,
46
47
  vocab: str,
48
+ engine_cfg: EngineConfig = EngineConfig(),
47
49
  cfg: Optional[Dict[str, Any]] = None,
48
50
  **kwargs: Any,
49
51
  ) -> None:
50
- super().__init__(url=model_path, **kwargs)
52
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
51
53
 
52
54
  self.vocab = vocab
53
55
  self.cfg = cfg
@@ -114,6 +116,7 @@ def _master(
114
116
  arch: str,
115
117
  model_path: str,
116
118
  load_in_8_bit: bool = False,
119
+ engine_cfg: EngineConfig = EngineConfig(),
117
120
  **kwargs: Any,
118
121
  ) -> MASTER:
119
122
  # Patch the config
@@ -125,10 +128,15 @@ def _master(
125
128
  # Patch the url
126
129
  model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
127
130
 
128
- return MASTER(model_path, cfg=_cfg, **kwargs)
131
+ return MASTER(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
129
132
 
130
133
 
131
- def master(model_path: str = default_cfgs["master"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> MASTER:
134
+ def master(
135
+ model_path: str = default_cfgs["master"]["url"],
136
+ load_in_8_bit: bool = False,
137
+ engine_cfg: EngineConfig = EngineConfig(),
138
+ **kwargs: Any,
139
+ ) -> MASTER:
132
140
  """MASTER as described in paper: <https://arxiv.org/pdf/1910.02562.pdf>`_.
133
141
 
134
142
  >>> import numpy as np
@@ -141,10 +149,11 @@ def master(model_path: str = default_cfgs["master"]["url"], load_in_8_bit: bool
141
149
  ----
142
150
  model_path: path to onnx model file, defaults to url in default_cfgs
143
151
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
152
+ engine_cfg: configuration for the inference engine
144
153
  **kwargs: keyword arguments passed to the MASTER architecture
145
154
 
146
155
  Returns:
147
156
  -------
148
157
  text recognition architecture
149
158
  """
150
- return _master("master", model_path, load_in_8_bit, **kwargs)
159
+ return _master("master", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -11,7 +11,7 @@ from scipy.special import softmax
11
11
 
12
12
  from onnxtr.utils import VOCABS
13
13
 
14
- from ...engine import Engine
14
+ from ...engine import Engine, EngineConfig
15
15
  from ..core import RecognitionPostProcessor
16
16
 
17
17
  __all__ = ["PARSeq", "parseq"]
@@ -35,6 +35,7 @@ class PARSeq(Engine):
35
35
  ----
36
36
  model_path: path to onnx model file
37
37
  vocab: vocabulary used for encoding
38
+ engine_cfg: configuration for the inference engine
38
39
  cfg: dictionary containing information about the model
39
40
  **kwargs: additional arguments to be passed to `Engine`
40
41
  """
@@ -43,10 +44,11 @@ class PARSeq(Engine):
43
44
  self,
44
45
  model_path: str,
45
46
  vocab: str,
47
+ engine_cfg: EngineConfig = EngineConfig(),
46
48
  cfg: Optional[Dict[str, Any]] = None,
47
49
  **kwargs: Any,
48
50
  ) -> None:
49
- super().__init__(url=model_path, **kwargs)
51
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
50
52
  self.vocab = vocab
51
53
  self.cfg = cfg
52
54
  self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
@@ -102,6 +104,7 @@ def _parseq(
102
104
  arch: str,
103
105
  model_path: str,
104
106
  load_in_8_bit: bool = False,
107
+ engine_cfg: EngineConfig = EngineConfig(),
105
108
  **kwargs: Any,
106
109
  ) -> PARSeq:
107
110
  # Patch the config
@@ -114,10 +117,15 @@ def _parseq(
114
117
  model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
115
118
 
116
119
  # Build the model
117
- return PARSeq(model_path, cfg=_cfg, **kwargs)
120
+ return PARSeq(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
118
121
 
119
122
 
120
- def parseq(model_path: str = default_cfgs["parseq"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> PARSeq:
123
+ def parseq(
124
+ model_path: str = default_cfgs["parseq"]["url"],
125
+ load_in_8_bit: bool = False,
126
+ engine_cfg: EngineConfig = EngineConfig(),
127
+ **kwargs: Any,
128
+ ) -> PARSeq:
121
129
  """PARSeq architecture from
122
130
  `"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
123
131
 
@@ -131,10 +139,11 @@ def parseq(model_path: str = default_cfgs["parseq"]["url"], load_in_8_bit: bool
131
139
  ----
132
140
  model_path: path to onnx model file, defaults to url in default_cfgs
133
141
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
142
+ engine_cfg: configuration for the inference engine
134
143
  **kwargs: keyword arguments of the PARSeq architecture
135
144
 
136
145
  Returns:
137
146
  -------
138
147
  text recognition architecture
139
148
  """
140
- return _parseq("parseq", model_path, load_in_8_bit, **kwargs)
149
+ return _parseq("parseq", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -11,7 +11,7 @@ from scipy.special import softmax
11
11
 
12
12
  from onnxtr.utils import VOCABS
13
13
 
14
- from ...engine import Engine
14
+ from ...engine import Engine, EngineConfig
15
15
  from ..core import RecognitionPostProcessor
16
16
 
17
17
  __all__ = ["SAR", "sar_resnet31"]
@@ -35,6 +35,7 @@ class SAR(Engine):
35
35
  ----
36
36
  model_path: path to onnx model file
37
37
  vocab: vocabulary used for encoding
38
+ engine_cfg: configuration for the inference engine
38
39
  cfg: dictionary containing information about the model
39
40
  **kwargs: additional arguments to be passed to `Engine`
40
41
  """
@@ -43,10 +44,11 @@ class SAR(Engine):
43
44
  self,
44
45
  model_path: str,
45
46
  vocab: str,
47
+ engine_cfg: EngineConfig = EngineConfig(),
46
48
  cfg: Optional[Dict[str, Any]] = None,
47
49
  **kwargs: Any,
48
50
  ) -> None:
49
- super().__init__(url=model_path, **kwargs)
51
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
50
52
  self.vocab = vocab
51
53
  self.cfg = cfg
52
54
  self.postprocessor = SARPostProcessor(self.vocab)
@@ -101,6 +103,7 @@ def _sar(
101
103
  arch: str,
102
104
  model_path: str,
103
105
  load_in_8_bit: bool = False,
106
+ engine_cfg: EngineConfig = EngineConfig(),
104
107
  **kwargs: Any,
105
108
  ) -> SAR:
106
109
  # Patch the config
@@ -113,11 +116,14 @@ def _sar(
113
116
  model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
114
117
 
115
118
  # Build the model
116
- return SAR(model_path, cfg=_cfg, **kwargs)
119
+ return SAR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
117
120
 
118
121
 
119
122
  def sar_resnet31(
120
- model_path: str = default_cfgs["sar_resnet31"]["url"], load_in_8_bit: bool = False, **kwargs: Any
123
+ model_path: str = default_cfgs["sar_resnet31"]["url"],
124
+ load_in_8_bit: bool = False,
125
+ engine_cfg: EngineConfig = EngineConfig(),
126
+ **kwargs: Any,
121
127
  ) -> SAR:
122
128
  """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong
123
129
  Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
@@ -132,10 +138,11 @@ def sar_resnet31(
132
138
  ----
133
139
  model_path: path to onnx model file, defaults to url in default_cfgs
134
140
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
141
+ engine_cfg: configuration for the inference engine
135
142
  **kwargs: keyword arguments of the SAR architecture
136
143
 
137
144
  Returns:
138
145
  -------
139
146
  text recognition architecture
140
147
  """
141
- return _sar("sar_resnet31", model_path, load_in_8_bit, **kwargs)
148
+ return _sar("sar_resnet31", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -11,7 +11,7 @@ from scipy.special import softmax
11
11
 
12
12
  from onnxtr.utils import VOCABS
13
13
 
14
- from ...engine import Engine
14
+ from ...engine import Engine, EngineConfig
15
15
  from ..core import RecognitionPostProcessor
16
16
 
17
17
  __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -43,6 +43,7 @@ class ViTSTR(Engine):
43
43
  ----
44
44
  model_path: path to onnx model file
45
45
  vocab: vocabulary used for encoding
46
+ engine_cfg: configuration for the inference engine
46
47
  cfg: dictionary containing information about the model
47
48
  **kwargs: additional arguments to be passed to `Engine`
48
49
  """
@@ -51,10 +52,11 @@ class ViTSTR(Engine):
51
52
  self,
52
53
  model_path: str,
53
54
  vocab: str,
55
+ engine_cfg: EngineConfig = EngineConfig(),
54
56
  cfg: Optional[Dict[str, Any]] = None,
55
57
  **kwargs: Any,
56
58
  ) -> None:
57
- super().__init__(url=model_path, **kwargs)
59
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
58
60
  self.vocab = vocab
59
61
  self.cfg = cfg
60
62
 
@@ -112,6 +114,7 @@ def _vitstr(
112
114
  arch: str,
113
115
  model_path: str,
114
116
  load_in_8_bit: bool = False,
117
+ engine_cfg: EngineConfig = EngineConfig(),
115
118
  **kwargs: Any,
116
119
  ) -> ViTSTR:
117
120
  # Patch the config
@@ -124,11 +127,14 @@ def _vitstr(
124
127
  model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
125
128
 
126
129
  # Build the model
127
- return ViTSTR(model_path, cfg=_cfg, **kwargs)
130
+ return ViTSTR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
128
131
 
129
132
 
130
133
  def vitstr_small(
131
- model_path: str = default_cfgs["vitstr_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any
134
+ model_path: str = default_cfgs["vitstr_small"]["url"],
135
+ load_in_8_bit: bool = False,
136
+ engine_cfg: EngineConfig = EngineConfig(),
137
+ **kwargs: Any,
132
138
  ) -> ViTSTR:
133
139
  """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
134
140
  <https://arxiv.org/pdf/2105.08582.pdf>`_.
@@ -143,17 +149,21 @@ def vitstr_small(
143
149
  ----
144
150
  model_path: path to onnx model file, defaults to url in default_cfgs
145
151
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
152
+ engine_cfg: configuration for the inference engine
146
153
  **kwargs: keyword arguments of the ViTSTR architecture
147
154
 
148
155
  Returns:
149
156
  -------
150
157
  text recognition architecture
151
158
  """
152
- return _vitstr("vitstr_small", model_path, load_in_8_bit, **kwargs)
159
+ return _vitstr("vitstr_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
153
160
 
154
161
 
155
162
  def vitstr_base(
156
- model_path: str = default_cfgs["vitstr_base"]["url"], load_in_8_bit: bool = False, **kwargs: Any
163
+ model_path: str = default_cfgs["vitstr_base"]["url"],
164
+ load_in_8_bit: bool = False,
165
+ engine_cfg: EngineConfig = EngineConfig(),
166
+ **kwargs: Any,
157
167
  ) -> ViTSTR:
158
168
  """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
159
169
  <https://arxiv.org/pdf/2105.08582.pdf>`_.
@@ -168,10 +178,11 @@ def vitstr_base(
168
178
  ----
169
179
  model_path: path to onnx model file, defaults to url in default_cfgs
170
180
  load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
181
+ engine_cfg: configuration for the inference engine
171
182
  **kwargs: keyword arguments of the ViTSTR architecture
172
183
 
173
184
  Returns:
174
185
  -------
175
186
  text recognition architecture
176
187
  """
177
- return _vitstr("vitstr_base", model_path, load_in_8_bit, **kwargs)
188
+ return _vitstr("vitstr_base", model_path, load_in_8_bit, engine_cfg, **kwargs)