onnxtr 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. onnxtr/contrib/base.py +1 -4
  2. onnxtr/io/elements.py +17 -4
  3. onnxtr/io/pdf.py +6 -3
  4. onnxtr/models/__init__.py +1 -0
  5. onnxtr/models/_utils.py +57 -20
  6. onnxtr/models/builder.py +24 -9
  7. onnxtr/models/classification/models/mobilenet.py +12 -5
  8. onnxtr/models/classification/zoo.py +20 -8
  9. onnxtr/models/detection/_utils/__init__.py +1 -0
  10. onnxtr/models/detection/_utils/base.py +66 -0
  11. onnxtr/models/detection/models/differentiable_binarization.py +27 -12
  12. onnxtr/models/detection/models/fast.py +30 -9
  13. onnxtr/models/detection/models/linknet.py +24 -9
  14. onnxtr/models/detection/postprocessor/base.py +4 -3
  15. onnxtr/models/detection/predictor/base.py +15 -1
  16. onnxtr/models/detection/zoo.py +14 -5
  17. onnxtr/models/engine.py +73 -7
  18. onnxtr/models/predictor/base.py +65 -42
  19. onnxtr/models/predictor/predictor.py +23 -16
  20. onnxtr/models/recognition/models/crnn.py +24 -9
  21. onnxtr/models/recognition/models/master.py +14 -5
  22. onnxtr/models/recognition/models/parseq.py +14 -5
  23. onnxtr/models/recognition/models/sar.py +12 -5
  24. onnxtr/models/recognition/models/vitstr.py +18 -7
  25. onnxtr/models/recognition/zoo.py +10 -7
  26. onnxtr/models/zoo.py +19 -3
  27. onnxtr/py.typed +0 -0
  28. onnxtr/utils/geometry.py +33 -12
  29. onnxtr/version.py +1 -1
  30. {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/METADATA +63 -24
  31. {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/RECORD +35 -32
  32. {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/WHEEL +1 -1
  33. {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/top_level.txt +0 -1
  34. {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/LICENSE +0 -0
  35. {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/zip-safe +0 -0
onnxtr/models/detection/models/differentiable_binarization.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Dict, Optional
 import numpy as np
 from scipy.special import expit
 
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
 from ..postprocessor.base import GeneralDetectionPostProcessor
 
 __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
@@ -33,8 +33,8 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_mobilenet_v3_large-1866973f.onnx",
-        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_mobilenet_v3_large_static_8_bit-51659bb9.onnx",
+        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large-4987e7bd.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large_static_8_bit-535a6f25.onnx",
     },
 }
 
@@ -45,6 +45,7 @@ class DBNet(Engine):
     Args:
     ----
         model_path: path or url to onnx model file
+        engine_cfg: configuration for the inference engine
         bin_thresh: threshold for binarization of the output feature map
         box_thresh: minimal objectness score to consider a box
         assume_straight_pages: if True, fit straight bounding boxes only
@@ -54,14 +55,15 @@ class DBNet(Engine):
 
     def __init__(
         self,
-        model_path,
+        model_path: str,
+        engine_cfg: Optional[EngineConfig] = None,
         bin_thresh: float = 0.3,
         box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
        cfg: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
-        super().__init__(url=model_path, **kwargs)
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
         self.cfg = cfg
         self.assume_straight_pages = assume_straight_pages
         self.postprocessor = GeneralDetectionPostProcessor(
@@ -91,16 +93,20 @@ def _dbnet(
     arch: str,
     model_path: str,
     load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
     **kwargs: Any,
 ) -> DBNet:
     # Patch the url
     model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
     # Build the model
-    return DBNet(model_path, cfg=default_cfgs[arch], **kwargs)
+    return DBNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
 
 
 def db_resnet34(
-    model_path: str = default_cfgs["db_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+    model_path: str = default_cfgs["db_resnet34"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
+    **kwargs: Any,
 ) -> DBNet:
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
     <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-34 backbone.
@@ -115,17 +121,21 @@ def db_resnet34(
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _dbnet("db_resnet34", model_path, load_in_8_bit, **kwargs)
+    return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
 def db_resnet50(
-    model_path: str = default_cfgs["db_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+    model_path: str = default_cfgs["db_resnet50"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
+    **kwargs: Any,
 ) -> DBNet:
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
     <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
@@ -140,17 +150,21 @@ def db_resnet50(
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _dbnet("db_resnet50", model_path, load_in_8_bit, **kwargs)
+    return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
 def db_mobilenet_v3_large(
-    model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+    model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
+    **kwargs: Any,
 ) -> DBNet:
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
     <https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
@@ -165,10 +179,11 @@ def db_mobilenet_v3_large(
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs)
+    return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
onnxtr/models/detection/models/fast.py CHANGED
@@ -9,7 +9,7 @@ from typing import Any, Dict, Optional
 import numpy as np
 from scipy.special import expit
 
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
 from ..postprocessor.base import GeneralDetectionPostProcessor
 
 __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"]
@@ -43,6 +43,7 @@ class FAST(Engine):
     Args:
     ----
         model_path: path or url to onnx model file
+        engine_cfg: configuration for the inference engine
         bin_thresh: threshold for binarization of the output feature map
         box_thresh: minimal objectness score to consider a box
         assume_straight_pages: if True, fit straight bounding boxes only
@@ -53,13 +54,14 @@ class FAST(Engine):
     def __init__(
         self,
         model_path: str,
+        engine_cfg: Optional[EngineConfig] = None,
         bin_thresh: float = 0.1,
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         cfg: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        super().__init__(url=model_path, **kwargs)
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
         self.cfg = cfg
         self.assume_straight_pages = assume_straight_pages
 
@@ -90,15 +92,21 @@ def _fast(
     arch: str,
     model_path: str,
     load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
     **kwargs: Any,
 ) -> FAST:
     if load_in_8_bit:
         logging.warning("FAST models do not support 8-bit quantization yet. Loading full precision model...")
     # Build the model
-    return FAST(model_path, cfg=default_cfgs[arch], **kwargs)
+    return FAST(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
 
 
-def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
+def fast_tiny(
+    model_path: str = default_cfgs["fast_tiny"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
+    **kwargs: Any,
+) -> FAST:
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
 
@@ -112,16 +120,22 @@ def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], load_in_8_bit:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the FAST architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _fast("fast_tiny", model_path, load_in_8_bit, **kwargs)
+    return _fast("fast_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
-def fast_small(model_path: str = default_cfgs["fast_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
+def fast_small(
+    model_path: str = default_cfgs["fast_small"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
+    **kwargs: Any,
+) -> FAST:
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
 
@@ -135,16 +149,22 @@ def fast_small(model_path: str = default_cfgs["fast_small"]["url"], load_in_8_bi
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the FAST architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _fast("fast_small", model_path, load_in_8_bit, **kwargs)
+    return _fast("fast_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
-def fast_base(model_path: str = default_cfgs["fast_base"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
+def fast_base(
+    model_path: str = default_cfgs["fast_base"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
+    **kwargs: Any,
+) -> FAST:
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
 
@@ -158,10 +178,11 @@ def fast_base(model_path: str = default_cfgs["fast_base"]["url"], load_in_8_bit:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the FAST architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _fast("fast_base", model_path, load_in_8_bit, **kwargs)
+    return _fast("fast_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
onnxtr/models/detection/models/linknet.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Dict, Optional
 import numpy as np
 from scipy.special import expit
 
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
 from ..postprocessor.base import GeneralDetectionPostProcessor
 
 __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
@@ -45,6 +45,7 @@ class LinkNet(Engine):
     Args:
     ----
         model_path: path or url to onnx model file
+        engine_cfg: configuration for the inference engine
         bin_thresh: threshold for binarization of the output feature map
         box_thresh: minimal objectness score to consider a box
         assume_straight_pages: if True, fit straight bounding boxes only
@@ -55,13 +56,14 @@ class LinkNet(Engine):
     def __init__(
         self,
         model_path: str,
+        engine_cfg: Optional[EngineConfig] = None,
         bin_thresh: float = 0.1,
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         cfg: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        super().__init__(url=model_path, **kwargs)
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
         self.cfg = cfg
         self.assume_straight_pages = assume_straight_pages
 
@@ -92,16 +94,20 @@ def _linknet(
     arch: str,
     model_path: str,
     load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
     **kwargs: Any,
 ) -> LinkNet:
     # Patch the url
     model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
     # Build the model
-    return LinkNet(model_path, cfg=default_cfgs[arch], **kwargs)
+    return LinkNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
 
 
 def linknet_resnet18(
-    model_path: str = default_cfgs["linknet_resnet18"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+    model_path: str = default_cfgs["linknet_resnet18"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
+    **kwargs: Any,
 ) -> LinkNet:
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
     <https://arxiv.org/pdf/1707.03718.pdf>`_.
@@ -116,17 +122,21 @@ def linknet_resnet18(
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the LinkNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _linknet("linknet_resnet18", model_path, load_in_8_bit, **kwargs)
+    return _linknet("linknet_resnet18", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
 def linknet_resnet34(
-    model_path: str = default_cfgs["linknet_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+    model_path: str = default_cfgs["linknet_resnet34"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
+    **kwargs: Any,
 ) -> LinkNet:
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
     <https://arxiv.org/pdf/1707.03718.pdf>`_.
@@ -141,17 +151,21 @@ def linknet_resnet34(
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the LinkNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _linknet("linknet_resnet34", model_path, load_in_8_bit, **kwargs)
+    return _linknet("linknet_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
 def linknet_resnet50(
-    model_path: str = default_cfgs["linknet_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+    model_path: str = default_cfgs["linknet_resnet50"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
+    **kwargs: Any,
 ) -> LinkNet:
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
     <https://arxiv.org/pdf/1707.03718.pdf>`_.
@@ -166,10 +180,11 @@ def linknet_resnet50(
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
        **kwargs: keyword arguments of the LinkNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _linknet("linknet_resnet50", model_path, load_in_8_bit, **kwargs)
+    return _linknet("linknet_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
onnxtr/models/detection/postprocessor/base.py CHANGED
@@ -109,7 +109,7 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -136,9 +136,10 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
             # compute relative box to get rid of img shape
             _box[:, 0] /= width
             _box[:, 1] /= height
-            boxes.append(_box)
+            # Add score to box as (0, score)
+            boxes.append(np.vstack([_box, np.array([0.0, score])]))
 
         if not self.assume_straight_pages:
-            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
         else:
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
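Note: the post-processor now appends the objectness score to each rotated polygon as an extra (0, score) row, which is why the empty-result shape grows from (0, 4, 2) to (0, 5, 2). A numpy illustration of the resulting layout:

    import numpy as np

    score = 0.87
    _box = np.array([[0.1, 0.2], [0.4, 0.2], [0.4, 0.3], [0.1, 0.3]])  # 4 relative corners (x, y)

    stacked = np.vstack([_box, np.array([0.0, score])])
    print(stacked.shape)    # (5, 2): four corners plus the trailing (0, score) row
    print(stacked[-1, 1])   # 0.87 -> per-box confidence recoverable downstream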
onnxtr/models/detection/predictor/base.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, List, Tuple, Union
 
 import numpy as np
 
+from onnxtr.models.detection._utils import _remove_padding
 from onnxtr.models.preprocessor import PreProcessor
 from onnxtr.utils.repr import NestedObject
 
@@ -38,6 +39,11 @@ class DetectionPredictor(NestedObject):
         return_maps: bool = False,
         **kwargs: Any,
     ) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]:
+        # Extract parameters from the preprocessor
+        preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
+        symmetric_pad = self.pre_processor.resize.symmetric_pad
+        assume_straight_pages = self.model.assume_straight_pages
+
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
@@ -47,7 +53,15 @@ class DetectionPredictor(NestedObject):
             self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
         ]
 
-        preds = [pred for batch in predicted_batches for pred in batch["preds"]]
+        # Remove padding from loc predictions
+        preds = _remove_padding(
+            pages,
+            [pred[0] for batch in predicted_batches for pred in batch["preds"]],
+            preserve_aspect_ratio=preserve_aspect_ratio,
+            symmetric_pad=symmetric_pad,
+            assume_straight_pages=assume_straight_pages,
+        )
+
         if return_maps:
             seg_maps = [pred for batch in predicted_batches for pred in batch["out_map"]]
             return preds, seg_maps
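Note: _remove_padding itself is introduced in onnxtr/models/detection/_utils/base.py (+66 lines, not shown in this excerpt). Since its body is not part of this diff, the following is only a sketch, under assumed semantics, of the correction such a helper has to perform in the simplest case (straight boxes, non-symmetric padding): predictions are relative to the letterboxed square input, so the shorter page axis must be stretched back out:

    import numpy as np

    def unpad_straight_boxes(page: np.ndarray, boxes: np.ndarray) -> np.ndarray:
        # Hypothetical helper: boxes is (N, 5) relative (xmin, ymin, xmax, ymax, score),
        # predicted on a page letterboxed into a square, padded at right/bottom.
        h, w = page.shape[:2]
        if h > w:    # width was padded -> rescale x coordinates
            boxes[:, [0, 2]] *= h / w
        elif w > h:  # height was padded -> rescale y coordinates
            boxes[:, [1, 3]] *= w / h
        return np.clip(boxes, 0, 1)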
onnxtr/models/detection/zoo.py CHANGED
@@ -3,9 +3,10 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any, Optional
 
 from .. import detection
+from ..engine import EngineConfig
 from ..preprocessor import PreProcessor
 from .predictor import DetectionPredictor
 
@@ -25,13 +26,19 @@ ARCHS = [
 
 
 def _predictor(
-    arch: Any, assume_straight_pages: bool = True, load_in_8_bit: bool = False, **kwargs: Any
+    arch: Any,
+    assume_straight_pages: bool = True,
+    load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
+    **kwargs: Any,
 ) -> DetectionPredictor:
     if isinstance(arch, str):
         if arch not in ARCHS:
             raise ValueError(f"unknown architecture '{arch}'")
 
-        _model = detection.__dict__[arch](assume_straight_pages=assume_straight_pages, load_in_8_bit=load_in_8_bit)
+        _model = detection.__dict__[arch](
+            assume_straight_pages=assume_straight_pages, load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg
+        )
     else:
         if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
             raise ValueError(f"unknown architecture: {type(arch)}")
@@ -41,7 +48,7 @@ def _predictor(
 
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size", 4)
+    kwargs["batch_size"] = kwargs.get("batch_size", 2)
     predictor = DetectionPredictor(
         PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
         _model,
@@ -53,6 +60,7 @@ def detection_predictor(
     arch: Any = "fast_base",
     assume_straight_pages: bool = True,
     load_in_8_bit: bool = False,
+    engine_cfg: Optional[EngineConfig] = None,
     **kwargs: Any,
 ) -> DetectionPredictor:
     """Text detection architecture.
@@ -68,10 +76,11 @@ def detection_predictor(
         arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
         assume_straight_pages: If True, fit straight boxes to the page
         load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: optional keyword arguments passed to the architecture
 
     Returns:
     -------
         Detection predictor
     """
-    return _predictor(arch, assume_straight_pages, load_in_8_bit, **kwargs)
+    return _predictor(arch, assume_straight_pages, load_in_8_bit, engine_cfg=engine_cfg, **kwargs)
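Note: two things change for zoo callers here: the default preprocessing batch_size drops from 4 to 2, and engine_cfg is threaded end to end. A usage sketch (assuming detection_predictor is re-exported from onnxtr.models, as in prior releases):

    from onnxtr.models import detection_predictor
    from onnxtr.models.engine import EngineConfig

    predictor = detection_predictor(
        arch="fast_base",
        assume_straight_pages=True,
        engine_cfg=EngineConfig(),  # CPU provider by default, CUDA prepended when available
        batch_size=4,               # restore the old default if throughput matters more than memory
    )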
onnxtr/models/engine.py CHANGED
@@ -3,14 +3,79 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
-import onnxruntime
+from onnxruntime import (
+    ExecutionMode,
+    GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
+    get_available_providers,
+    get_device,
+)
 
 from onnxtr.utils.data import download_from_url
 from onnxtr.utils.geometry import shape_translate
 
+__all__ = ["EngineConfig"]
+
+
+class EngineConfig:
+    """Implements a configuration class for the engine of a model
+
+    Args:
+    ----
+        providers: list of providers to use for inference ref.: https://onnxruntime.ai/docs/execution-providers/
+        session_options: configuration for the inference session ref.: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
+    """
+
+    def __init__(
+        self,
+        providers: Optional[Union[List[Tuple[str, Dict[str, Any]]], List[str]]] = None,
+        session_options: Optional[SessionOptions] = None,
+    ):
+        self._providers = providers or self._init_providers()
+        self._session_options = session_options or self._init_sess_opts()
+
+    def _init_providers(self) -> List[Tuple[str, Dict[str, Any]]]:
+        providers: Any = [("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
+        available_providers = get_available_providers()
+        if "CUDAExecutionProvider" in available_providers and get_device() == "GPU":  # pragma: no cover
+            providers.insert(
+                0,
+                (
+                    "CUDAExecutionProvider",
+                    {
+                        "device_id": 0,
+                        "arena_extend_strategy": "kNextPowerOfTwo",
+                        "cudnn_conv_algo_search": "DEFAULT",
+                        "do_copy_in_default_stream": True,
+                    },
+                ),
+            )
+        return providers
+
+    def _init_sess_opts(self) -> SessionOptions:
+        session_options = SessionOptions()
+        session_options.enable_cpu_mem_arena = True
+        session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
+        session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+        session_options.intra_op_num_threads = -1
+        session_options.inter_op_num_threads = -1
+        return session_options
+
+    @property
+    def providers(self) -> Union[List[Tuple[str, Dict[str, Any]]], List[str]]:
+        return self._providers
+
+    @property
+    def session_options(self) -> SessionOptions:
+        return self._session_options
+
+    def __repr__(self) -> str:
+        return f"EngineConfig(providers={self.providers})"
+
 
 
 class Engine:
@@ -18,15 +83,16 @@ class Engine:
     """Implements an abstract class for the engine of a model

     Args:
     ----
         url: the url to use to download a model if needed
-        providers: list of providers to use for inference
+        engine_cfg: the configuration of the engine
         **kwargs: additional arguments to be passed to `download_from_url`
     """
 
-    def __init__(
-        self, url: str, providers: List[str] = ["CPUExecutionProvider", "CUDAExecutionProvider"], **kwargs: Any
-    ) -> None:
+    def __init__(self, url: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any) -> None:
+        engine_cfg = engine_cfg if isinstance(engine_cfg, EngineConfig) else EngineConfig()
         archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
-        self.runtime = onnxruntime.InferenceSession(archive_path, providers=providers)
+        self.session_options = engine_cfg.session_options
+        self.providers = engine_cfg.providers
+        self.runtime = InferenceSession(archive_path, providers=self.providers, sess_options=self.session_options)
         self.runtime_inputs = self.runtime.get_inputs()[0]
         self.tf_exported = int(self.runtime_inputs.shape[-1]) == 3
         self.fixed_batch_size: Union[int, str] = self.runtime_inputs.shape[
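Note: EngineConfig bundles exactly the two knobs onnxruntime's InferenceSession accepts, the provider list and a SessionOptions object. A sketch of overriding both (the option values are illustrative, not recommendations):

    from onnxruntime import GraphOptimizationLevel, SessionOptions

    from onnxtr.models.engine import EngineConfig

    opts = SessionOptions()
    opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
    opts.intra_op_num_threads = 4  # pin the thread count instead of the -1 defaults above

    cfg = EngineConfig(
        providers=["CPUExecutionProvider"],  # the plain string form is also accepted
        session_options=opts,
    )
    # Any model built with engine_cfg=cfg constructs its runtime as
    # InferenceSession(path, providers=cfg.providers, sess_options=cfg.session_options)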