onnxtr 0.1.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. onnxtr/io/elements.py +17 -4
  2. onnxtr/io/pdf.py +6 -3
  3. onnxtr/models/__init__.py +1 -0
  4. onnxtr/models/_utils.py +57 -20
  5. onnxtr/models/builder.py +24 -9
  6. onnxtr/models/classification/models/mobilenet.py +25 -7
  7. onnxtr/models/classification/predictor/base.py +1 -0
  8. onnxtr/models/classification/zoo.py +22 -7
  9. onnxtr/models/detection/_utils/__init__.py +1 -0
  10. onnxtr/models/detection/_utils/base.py +66 -0
  11. onnxtr/models/detection/models/differentiable_binarization.py +41 -11
  12. onnxtr/models/detection/models/fast.py +37 -9
  13. onnxtr/models/detection/models/linknet.py +39 -9
  14. onnxtr/models/detection/postprocessor/base.py +4 -3
  15. onnxtr/models/detection/predictor/base.py +15 -1
  16. onnxtr/models/detection/zoo.py +16 -3
  17. onnxtr/models/engine.py +75 -9
  18. onnxtr/models/predictor/base.py +69 -42
  19. onnxtr/models/predictor/predictor.py +22 -15
  20. onnxtr/models/recognition/models/crnn.py +39 -9
  21. onnxtr/models/recognition/models/master.py +19 -5
  22. onnxtr/models/recognition/models/parseq.py +20 -5
  23. onnxtr/models/recognition/models/sar.py +19 -5
  24. onnxtr/models/recognition/models/vitstr.py +31 -9
  25. onnxtr/models/recognition/zoo.py +12 -6
  26. onnxtr/models/zoo.py +22 -0
  27. onnxtr/py.typed +0 -0
  28. onnxtr/utils/geometry.py +33 -12
  29. onnxtr/version.py +1 -1
  30. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/METADATA +81 -16
  31. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/RECORD +35 -32
  32. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/WHEEL +1 -1
  33. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/top_level.txt +0 -1
  34. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/LICENSE +0 -0
  35. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/zip-safe +0 -0
@@ -8,10 +8,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
8
8
  import numpy as np
9
9
 
10
10
  from onnxtr.models.builder import DocumentBuilder
11
- from onnxtr.utils.geometry import extract_crops, extract_rcrops
11
+ from onnxtr.models.engine import EngineConfig
12
+ from onnxtr.utils.geometry import extract_crops, extract_rcrops, rotate_image
12
13
 
13
- from .._utils import rectify_crops, rectify_loc_preds
14
- from ..classification import crop_orientation_predictor
14
+ from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
15
+ from ..classification import crop_orientation_predictor, page_orientation_predictor
15
16
  from ..classification.predictor import OrientationPredictor
16
17
  from ..detection.zoo import ARCHS as DETECTION_ARCHS
17
18
  from ..recognition.zoo import ARCHS as RECOGNITION_ARCHS
@@ -31,10 +32,15 @@ class _OCRPredictor:
31
32
  accordingly. Doing so will improve performances for documents with page-uniform rotations.
32
33
  preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
33
34
  symmetric_pad: if True and preserve_aspect_ratio is True, pads the image symmetrically.
35
+ detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
36
+ page. Doing so will slightly deteriorate the overall latency.
37
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
38
+ clf_engine_cfg: configuration of the orientation classification engine
34
39
  **kwargs: keyword args of `DocumentBuilder`
35
40
  """
36
41
 
37
42
  crop_orientation_predictor: Optional[OrientationPredictor]
43
+ page_orientation_predictor: Optional[OrientationPredictor]
38
44
 
39
45
  def __init__(
40
46
  self,
@@ -42,16 +48,75 @@ class _OCRPredictor:
42
48
  straighten_pages: bool = False,
43
49
  preserve_aspect_ratio: bool = True,
44
50
  symmetric_pad: bool = True,
51
+ detect_orientation: bool = False,
52
+ load_in_8_bit: bool = False,
53
+ clf_engine_cfg: EngineConfig = EngineConfig(),
45
54
  **kwargs: Any,
46
55
  ) -> None:
47
56
  self.assume_straight_pages = assume_straight_pages
48
57
  self.straighten_pages = straighten_pages
49
- self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor()
58
+ self.crop_orientation_predictor = (
59
+ None
60
+ if assume_straight_pages
61
+ else crop_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
62
+ )
63
+ self.page_orientation_predictor = (
64
+ page_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
65
+ if detect_orientation or straighten_pages or not assume_straight_pages
66
+ else None
67
+ )
50
68
  self.doc_builder = DocumentBuilder(**kwargs)
51
69
  self.preserve_aspect_ratio = preserve_aspect_ratio
52
70
  self.symmetric_pad = symmetric_pad
53
71
  self.hooks: List[Callable] = []
54
72
 
73
+ def _general_page_orientations(
74
+ self,
75
+ pages: List[np.ndarray],
76
+ ) -> List[Tuple[int, float]]:
77
+ _, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
78
+ # Flatten to list of tuples with (value, confidence)
79
+ page_orientations = [
80
+ (orientation, prob)
81
+ for page_classes, page_probs in zip(classes, probs)
82
+ for orientation, prob in zip(page_classes, page_probs)
83
+ ]
84
+ return page_orientations
85
+
86
+ def _get_orientations(
87
+ self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
88
+ ) -> Tuple[List[Tuple[int, float]], List[int]]:
89
+ general_pages_orientations = self._general_page_orientations(pages)
90
+ origin_page_orientations = [
91
+ estimate_orientation(seq_map, general_orientation)
92
+ for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
93
+ ]
94
+ return general_pages_orientations, origin_page_orientations
95
+
96
+ def _straighten_pages(
97
+ self,
98
+ pages: List[np.ndarray],
99
+ seg_maps: List[np.ndarray],
100
+ general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
101
+ origin_pages_orientations: Optional[List[int]] = None,
102
+ ) -> List[np.ndarray]:
103
+ general_pages_orientations = (
104
+ general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
105
+ )
106
+ origin_pages_orientations = (
107
+ origin_pages_orientations
108
+ if origin_pages_orientations
109
+ else [
110
+ estimate_orientation(seq_map, general_orientation)
111
+ for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
112
+ ]
113
+ )
114
+ return [
115
+ # We expand if the page is wider than tall and the angle is 90 or -90
116
+ rotate_image(page, angle, expand=page.shape[1] > page.shape[0] and abs(angle) == 90)
117
+ for page, angle in zip(pages, origin_pages_orientations)
118
+ ]
119
+
55
120
  @staticmethod
56
121
  def _generate_crops(
57
122
  pages: List[np.ndarray],
@@ -106,44 +171,6 @@ class _OCRPredictor:
106
171
  ]
107
172
  return rect_crops, rect_loc_preds, crop_orientations # type: ignore[return-value]
108
173
 
109
- def _remove_padding(
110
- self,
111
- pages: List[np.ndarray],
112
- loc_preds: List[np.ndarray],
113
- ) -> List[np.ndarray]:
114
- if self.preserve_aspect_ratio:
115
- # Rectify loc_preds to remove padding
116
- rectified_preds = []
117
- for page, loc_pred in zip(pages, loc_preds):
118
- h, w = page.shape[0], page.shape[1]
119
- if h > w:
120
- # y unchanged, dilate x coord
121
- if self.symmetric_pad:
122
- if self.assume_straight_pages:
123
- loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
124
- else:
125
- loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
126
- else:
127
- if self.assume_straight_pages:
128
- loc_pred[:, [0, 2]] *= h / w
129
- else:
130
- loc_pred[:, :, 0] *= h / w
131
- elif w > h:
132
- # x unchanged, dilate y coord
133
- if self.symmetric_pad:
134
- if self.assume_straight_pages:
135
- loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
136
- else:
137
- loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
138
- else:
139
- if self.assume_straight_pages:
140
- loc_pred[:, [1, 3]] *= w / h
141
- else:
142
- loc_pred[:, :, 1] *= w / h
143
- rectified_preds.append(loc_pred)
144
- return rectified_preds
145
- return loc_preds
146
-
147
174
  @staticmethod
148
175
  def _process_predictions(
149
176
  loc_preds: List[np.ndarray],
@@ -8,10 +8,11 @@ from typing import Any, List
8
8
  import numpy as np
9
9
 
10
10
  from onnxtr.io.elements import Document
11
- from onnxtr.models._utils import estimate_orientation, get_language
11
+ from onnxtr.models._utils import get_language
12
12
  from onnxtr.models.detection.predictor import DetectionPredictor
13
+ from onnxtr.models.engine import EngineConfig
13
14
  from onnxtr.models.recognition.predictor import RecognitionPredictor
14
- from onnxtr.utils.geometry import rotate_image
15
+ from onnxtr.utils.geometry import detach_scores
15
16
  from onnxtr.utils.repr import NestedObject
16
17
 
17
18
  from .base import _OCRPredictor
@@ -35,6 +36,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
35
36
  page. Doing so will slightly deteriorate the overall latency.
36
37
  detect_language: if True, the language prediction will be added to the predictions for each
37
38
  page. Doing so will slightly deteriorate the overall latency.
39
+ clf_engine_cfg: configuration of the orientation classification engine
38
40
  **kwargs: keyword args of `DocumentBuilder`
39
41
  """
40
42
 
@@ -50,12 +52,20 @@ class OCRPredictor(NestedObject, _OCRPredictor):
50
52
  symmetric_pad: bool = True,
51
53
  detect_orientation: bool = False,
52
54
  detect_language: bool = False,
55
+ clf_engine_cfg: EngineConfig = EngineConfig(),
53
56
  **kwargs: Any,
54
57
  ) -> None:
55
58
  self.det_predictor = det_predictor
56
59
  self.reco_predictor = reco_predictor
57
60
  _OCRPredictor.__init__(
58
- self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
61
+ self,
62
+ assume_straight_pages,
63
+ straighten_pages,
64
+ preserve_aspect_ratio,
65
+ symmetric_pad,
66
+ detect_orientation,
67
+ clf_engine_cfg=clf_engine_cfg,
68
+ **kwargs,
59
69
  )
60
70
  self.detect_orientation = detect_orientation
61
71
  self.detect_language = detect_language
@@ -80,26 +90,22 @@ class OCRPredictor(NestedObject, _OCRPredictor):
80
90
  for out_map in out_maps
81
91
  ]
82
92
  if self.detect_orientation:
83
- origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
93
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
84
94
  orientations = [
85
- {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
95
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
86
96
  ]
87
97
  else:
88
98
  orientations = None
99
+ general_pages_orientations = None
100
+ origin_pages_orientations = None
89
101
  if self.straighten_pages:
90
- origin_page_orientations = (
91
- origin_page_orientations
92
- if self.detect_orientation
93
- else [estimate_orientation(seq_map) for seq_map in seg_maps]
94
- )
95
- pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
102
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
103
+
96
104
  # forward again to get predictions on straight pages
97
105
  loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
98
106
 
99
- loc_preds = [loc_pred[0] for loc_pred in loc_preds]
100
-
101
- # Rectify crops if aspect ratio
102
- loc_preds = self._remove_padding(pages, loc_preds)
107
+ # Detach objectness scores from loc_preds
108
+ loc_preds, objectness_scores = detach_scores(loc_preds) # type: ignore[arg-type]
103
109
 
104
110
  # Apply hooks to loc_preds if any
105
111
  for hook in self.hooks:
@@ -136,6 +142,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
136
142
  out = self.doc_builder(
137
143
  pages,
138
144
  boxes,
145
+ objectness_scores,
139
146
  text_preds,
140
147
  origin_page_shapes, # type: ignore[arg-type]
141
148
  crop_orientations,
@@ -12,7 +12,7 @@ from scipy.special import softmax
12
12
 
13
13
  from onnxtr.utils import VOCABS
14
14
 
15
- from ...engine import Engine
15
+ from ...engine import Engine, EngineConfig
16
16
  from ..core import RecognitionPostProcessor
17
17
 
18
18
  __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -24,6 +24,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
24
24
  "input_shape": (3, 32, 128),
25
25
  "vocab": VOCABS["legacy_french"],
26
26
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_vgg16_bn-662979cc.onnx",
27
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/crnn_vgg16_bn_static_8_bit-bce050c7.onnx",
27
28
  },
28
29
  "crnn_mobilenet_v3_small": {
29
30
  "mean": (0.694, 0.695, 0.693),
@@ -31,6 +32,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
31
32
  "input_shape": (3, 32, 128),
32
33
  "vocab": VOCABS["french"],
33
34
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_mobilenet_v3_small-bded4d49.onnx",
35
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/crnn_mobilenet_v3_small_static_8_bit-4949006f.onnx",
34
36
  },
35
37
  "crnn_mobilenet_v3_large": {
36
38
  "mean": (0.694, 0.695, 0.693),
@@ -38,6 +40,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
38
40
  "input_shape": (3, 32, 128),
39
41
  "vocab": VOCABS["french"],
40
42
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/crnn_mobilenet_v3_large-d42e8185.onnx",
43
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/crnn_mobilenet_v3_large_static_8_bit-459e856d.onnx",
41
44
  },
42
45
  }
43
46
 
@@ -110,6 +113,7 @@ class CRNN(Engine):
110
113
  ----
111
114
  model_path: path or url to onnx model file
112
115
  vocab: vocabulary used for encoding
116
+ engine_cfg: configuration for the inference engine
113
117
  cfg: configuration dictionary
114
118
  **kwargs: additional arguments to be passed to `Engine`
115
119
  """
@@ -120,10 +124,11 @@ class CRNN(Engine):
120
124
  self,
121
125
  model_path: str,
122
126
  vocab: str,
127
+ engine_cfg: EngineConfig = EngineConfig(),
123
128
  cfg: Optional[Dict[str, Any]] = None,
124
129
  **kwargs: Any,
125
130
  ) -> None:
126
- super().__init__(url=model_path, **kwargs)
131
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
127
132
  self.vocab = vocab
128
133
  self.cfg = cfg
129
134
  self.postprocessor = CRNNPostProcessor(self.vocab)
@@ -148,6 +153,8 @@ class CRNN(Engine):
148
153
  def _crnn(
149
154
  arch: str,
150
155
  model_path: str,
156
+ load_in_8_bit: bool = False,
157
+ engine_cfg: EngineConfig = EngineConfig(),
151
158
  **kwargs: Any,
152
159
  ) -> CRNN:
153
160
  kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
@@ -155,12 +162,19 @@ def _crnn(
155
162
  _cfg = deepcopy(default_cfgs[arch])
156
163
  _cfg["vocab"] = kwargs["vocab"]
157
164
  _cfg["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
165
+ # Patch the url
166
+ model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
158
167
 
159
168
  # Build the model
160
- return CRNN(model_path, cfg=_cfg, **kwargs)
169
+ return CRNN(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
161
170
 
162
171
 
163
- def crnn_vgg16_bn(model_path: str = default_cfgs["crnn_vgg16_bn"]["url"], **kwargs: Any) -> CRNN:
172
+ def crnn_vgg16_bn(
173
+ model_path: str = default_cfgs["crnn_vgg16_bn"]["url"],
174
+ load_in_8_bit: bool = False,
175
+ engine_cfg: EngineConfig = EngineConfig(),
176
+ **kwargs: Any,
177
+ ) -> CRNN:
164
178
  """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
165
179
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
166
180
 
@@ -173,16 +187,23 @@ def crnn_vgg16_bn(model_path: str = default_cfgs["crnn_vgg16_bn"]["url"], **kwar
173
187
  Args:
174
188
  ----
175
189
  model_path: path to onnx model file, defaults to url in default_cfgs
190
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
191
+ engine_cfg: configuration for the inference engine
176
192
  **kwargs: keyword arguments of the CRNN architecture
177
193
 
178
194
  Returns:
179
195
  -------
180
196
  text recognition architecture
181
197
  """
182
- return _crnn("crnn_vgg16_bn", model_path, **kwargs)
198
+ return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, engine_cfg, **kwargs)
183
199
 
184
200
 
185
- def crnn_mobilenet_v3_small(model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"], **kwargs: Any) -> CRNN:
201
+ def crnn_mobilenet_v3_small(
202
+ model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"],
203
+ load_in_8_bit: bool = False,
204
+ engine_cfg: EngineConfig = EngineConfig(),
205
+ **kwargs: Any,
206
+ ) -> CRNN:
186
207
  """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
187
208
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
188
209
 
@@ -195,16 +216,23 @@ def crnn_mobilenet_v3_small(model_path: str = default_cfgs["crnn_mobilenet_v3_sm
195
216
  Args:
196
217
  ----
197
218
  model_path: path to onnx model file, defaults to url in default_cfgs
219
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
220
+ engine_cfg: configuration for the inference engine
198
221
  **kwargs: keyword arguments of the CRNN architecture
199
222
 
200
223
  Returns:
201
224
  -------
202
225
  text recognition architecture
203
226
  """
204
- return _crnn("crnn_mobilenet_v3_small", model_path, **kwargs)
227
+ return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
205
228
 
206
229
 
207
- def crnn_mobilenet_v3_large(model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"], **kwargs: Any) -> CRNN:
230
+ def crnn_mobilenet_v3_large(
231
+ model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"],
232
+ load_in_8_bit: bool = False,
233
+ engine_cfg: EngineConfig = EngineConfig(),
234
+ **kwargs: Any,
235
+ ) -> CRNN:
208
236
  """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
209
237
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
210
238
 
@@ -217,10 +245,12 @@ def crnn_mobilenet_v3_large(model_path: str = default_cfgs["crnn_mobilenet_v3_la
217
245
  Args:
218
246
  ----
219
247
  model_path: path to onnx model file, defaults to url in default_cfgs
248
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
249
+ engine_cfg: configuration for the inference engine
220
250
  **kwargs: keyword arguments of the CRNN architecture
221
251
 
222
252
  Returns:
223
253
  -------
224
254
  text recognition architecture
225
255
  """
226
- return _crnn("crnn_mobilenet_v3_large", model_path, **kwargs)
256
+ return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -11,7 +11,7 @@ from scipy.special import softmax
11
11
 
12
12
  from onnxtr.utils import VOCABS
13
13
 
14
- from ...engine import Engine
14
+ from ...engine import Engine, EngineConfig
15
15
  from ..core import RecognitionPostProcessor
16
16
 
17
17
  __all__ = ["MASTER", "master"]
@@ -24,6 +24,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
24
24
  "input_shape": (3, 32, 128),
25
25
  "vocab": VOCABS["french"],
26
26
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/master-b1287fcd.onnx",
27
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/master_dynamic_8_bit-d8bd8206.onnx",
27
28
  },
28
29
  }
29
30
 
@@ -35,6 +36,7 @@ class MASTER(Engine):
35
36
  ----
36
37
  model_path: path or url to onnx model file
37
38
  vocab: vocabulary, (without EOS, SOS, PAD)
39
+ engine_cfg: configuration for the inference engine
38
40
  cfg: dictionary containing information about the model
39
41
  **kwargs: additional arguments to be passed to `Engine`
40
42
  """
@@ -43,10 +45,11 @@ class MASTER(Engine):
43
45
  self,
44
46
  model_path: str,
45
47
  vocab: str,
48
+ engine_cfg: EngineConfig = EngineConfig(),
46
49
  cfg: Optional[Dict[str, Any]] = None,
47
50
  **kwargs: Any,
48
51
  ) -> None:
49
- super().__init__(url=model_path, **kwargs)
52
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
50
53
 
51
54
  self.vocab = vocab
52
55
  self.cfg = cfg
@@ -112,6 +115,8 @@ class MASTERPostProcessor(RecognitionPostProcessor):
112
115
  def _master(
113
116
  arch: str,
114
117
  model_path: str,
118
+ load_in_8_bit: bool = False,
119
+ engine_cfg: EngineConfig = EngineConfig(),
115
120
  **kwargs: Any,
116
121
  ) -> MASTER:
117
122
  # Patch the config
@@ -120,11 +125,18 @@ def _master(
120
125
  _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
121
126
 
122
127
  kwargs["vocab"] = _cfg["vocab"]
128
+ # Patch the url
129
+ model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
123
130
 
124
- return MASTER(model_path, cfg=_cfg, **kwargs)
131
+ return MASTER(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
125
132
 
126
133
 
127
- def master(model_path: str = default_cfgs["master"]["url"], **kwargs: Any) -> MASTER:
134
+ def master(
135
+ model_path: str = default_cfgs["master"]["url"],
136
+ load_in_8_bit: bool = False,
137
+ engine_cfg: EngineConfig = EngineConfig(),
138
+ **kwargs: Any,
139
+ ) -> MASTER:
128
140
  """MASTER as described in paper: <https://arxiv.org/pdf/1910.02562.pdf>`_.
129
141
 
130
142
  >>> import numpy as np
@@ -136,10 +148,12 @@ def master(model_path: str = default_cfgs["master"]["url"], **kwargs: Any) -> MA
136
148
  Args:
137
149
  ----
138
150
  model_path: path to onnx model file, defaults to url in default_cfgs
151
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
152
+ engine_cfg: configuration for the inference engine
139
153
  **kwargs: keyword arguments passed to the MASTER architecture
140
154
 
141
155
  Returns:
142
156
  -------
143
157
  text recognition architecture
144
158
  """
145
- return _master("master", model_path, **kwargs)
159
+ return _master("master", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -11,7 +11,7 @@ from scipy.special import softmax
11
11
 
12
12
  from onnxtr.utils import VOCABS
13
13
 
14
- from ...engine import Engine
14
+ from ...engine import Engine, EngineConfig
15
15
  from ..core import RecognitionPostProcessor
16
16
 
17
17
  __all__ = ["PARSeq", "parseq"]
@@ -23,6 +23,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
23
23
  "input_shape": (3, 32, 128),
24
24
  "vocab": VOCABS["french"],
25
25
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/parseq-00b40714.onnx",
26
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/parseq_dynamic_8_bit-5b04d9f7.onnx",
26
27
  },
27
28
  }
28
29
 
@@ -32,7 +33,9 @@ class PARSeq(Engine):
32
33
 
33
34
  Args:
34
35
  ----
36
+ model_path: path to onnx model file
35
37
  vocab: vocabulary used for encoding
38
+ engine_cfg: configuration for the inference engine
36
39
  cfg: dictionary containing information about the model
37
40
  **kwargs: additional arguments to be passed to `Engine`
38
41
  """
@@ -41,10 +44,11 @@ class PARSeq(Engine):
41
44
  self,
42
45
  model_path: str,
43
46
  vocab: str,
47
+ engine_cfg: EngineConfig = EngineConfig(),
44
48
  cfg: Optional[Dict[str, Any]] = None,
45
49
  **kwargs: Any,
46
50
  ) -> None:
47
- super().__init__(url=model_path, **kwargs)
51
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
48
52
  self.vocab = vocab
49
53
  self.cfg = cfg
50
54
  self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
@@ -99,6 +103,8 @@ class PARSeqPostProcessor(RecognitionPostProcessor):
99
103
  def _parseq(
100
104
  arch: str,
101
105
  model_path: str,
106
+ load_in_8_bit: bool = False,
107
+ engine_cfg: EngineConfig = EngineConfig(),
102
108
  **kwargs: Any,
103
109
  ) -> PARSeq:
104
110
  # Patch the config
@@ -107,12 +113,19 @@ def _parseq(
107
113
  _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
108
114
 
109
115
  kwargs["vocab"] = _cfg["vocab"]
116
+ # Patch the url
117
+ model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
110
118
 
111
119
  # Build the model
112
- return PARSeq(model_path, cfg=_cfg, **kwargs)
120
+ return PARSeq(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
113
121
 
114
122
 
115
- def parseq(model_path: str = default_cfgs["parseq"]["url"], **kwargs: Any) -> PARSeq:
123
+ def parseq(
124
+ model_path: str = default_cfgs["parseq"]["url"],
125
+ load_in_8_bit: bool = False,
126
+ engine_cfg: EngineConfig = EngineConfig(),
127
+ **kwargs: Any,
128
+ ) -> PARSeq:
116
129
  """PARSeq architecture from
117
130
  `"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
118
131
 
@@ -125,10 +138,12 @@ def parseq(model_path: str = default_cfgs["parseq"]["url"], **kwargs: Any) -> PA
125
138
  Args:
126
139
  ----
127
140
  model_path: path to onnx model file, defaults to url in default_cfgs
141
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
142
+ engine_cfg: configuration for the inference engine
128
143
  **kwargs: keyword arguments of the PARSeq architecture
129
144
 
130
145
  Returns:
131
146
  -------
132
147
  text recognition architecture
133
148
  """
134
- return _parseq("parseq", model_path, **kwargs)
149
+ return _parseq("parseq", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -11,7 +11,7 @@ from scipy.special import softmax
11
11
 
12
12
  from onnxtr.utils import VOCABS
13
13
 
14
- from ...engine import Engine
14
+ from ...engine import Engine, EngineConfig
15
15
  from ..core import RecognitionPostProcessor
16
16
 
17
17
  __all__ = ["SAR", "sar_resnet31"]
@@ -23,6 +23,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
23
23
  "input_shape": (3, 32, 128),
24
24
  "vocab": VOCABS["french"],
25
25
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/sar_resnet31-395f8005.onnx",
26
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/sar_resnet31_static_8_bit-c07316bc.onnx",
26
27
  },
27
28
  }
28
29
 
@@ -34,6 +35,7 @@ class SAR(Engine):
34
35
  ----
35
36
  model_path: path to onnx model file
36
37
  vocab: vocabulary used for encoding
38
+ engine_cfg: configuration for the inference engine
37
39
  cfg: dictionary containing information about the model
38
40
  **kwargs: additional arguments to be passed to `Engine`
39
41
  """
@@ -42,10 +44,11 @@ class SAR(Engine):
42
44
  self,
43
45
  model_path: str,
44
46
  vocab: str,
47
+ engine_cfg: EngineConfig = EngineConfig(),
45
48
  cfg: Optional[Dict[str, Any]] = None,
46
49
  **kwargs: Any,
47
50
  ) -> None:
48
- super().__init__(url=model_path, **kwargs)
51
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
49
52
  self.vocab = vocab
50
53
  self.cfg = cfg
51
54
  self.postprocessor = SARPostProcessor(self.vocab)
@@ -99,6 +102,8 @@ class SARPostProcessor(RecognitionPostProcessor):
99
102
  def _sar(
100
103
  arch: str,
101
104
  model_path: str,
105
+ load_in_8_bit: bool = False,
106
+ engine_cfg: EngineConfig = EngineConfig(),
102
107
  **kwargs: Any,
103
108
  ) -> SAR:
104
109
  # Patch the config
@@ -107,12 +112,19 @@ def _sar(
107
112
  _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
108
113
 
109
114
  kwargs["vocab"] = _cfg["vocab"]
115
+ # Patch the url
116
+ model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
110
117
 
111
118
  # Build the model
112
- return SAR(model_path, cfg=_cfg, **kwargs)
119
+ return SAR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
113
120
 
114
121
 
115
- def sar_resnet31(model_path: str = default_cfgs["sar_resnet31"]["url"], **kwargs: Any) -> SAR:
122
+ def sar_resnet31(
123
+ model_path: str = default_cfgs["sar_resnet31"]["url"],
124
+ load_in_8_bit: bool = False,
125
+ engine_cfg: EngineConfig = EngineConfig(),
126
+ **kwargs: Any,
127
+ ) -> SAR:
116
128
  """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong
117
129
  Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
118
130
 
@@ -125,10 +137,12 @@ def sar_resnet31(model_path: str = default_cfgs["sar_resnet31"]["url"], **kwargs
125
137
  Args:
126
138
  ----
127
139
  model_path: path to onnx model file, defaults to url in default_cfgs
140
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
141
+ engine_cfg: configuration for the inference engine
128
142
  **kwargs: keyword arguments of the SAR architecture
129
143
 
130
144
  Returns:
131
145
  -------
132
146
  text recognition architecture
133
147
  """
134
- return _sar("sar_resnet31", model_path, **kwargs)
148
+ return _sar("sar_resnet31", model_path, load_in_8_bit, engine_cfg, **kwargs)