onnxtr 0.1.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. onnxtr/io/elements.py +17 -4
  2. onnxtr/io/pdf.py +6 -3
  3. onnxtr/models/__init__.py +1 -0
  4. onnxtr/models/_utils.py +57 -20
  5. onnxtr/models/builder.py +24 -9
  6. onnxtr/models/classification/models/mobilenet.py +25 -7
  7. onnxtr/models/classification/predictor/base.py +1 -0
  8. onnxtr/models/classification/zoo.py +22 -7
  9. onnxtr/models/detection/_utils/__init__.py +1 -0
  10. onnxtr/models/detection/_utils/base.py +66 -0
  11. onnxtr/models/detection/models/differentiable_binarization.py +41 -11
  12. onnxtr/models/detection/models/fast.py +37 -9
  13. onnxtr/models/detection/models/linknet.py +39 -9
  14. onnxtr/models/detection/postprocessor/base.py +4 -3
  15. onnxtr/models/detection/predictor/base.py +15 -1
  16. onnxtr/models/detection/zoo.py +16 -3
  17. onnxtr/models/engine.py +75 -9
  18. onnxtr/models/predictor/base.py +69 -42
  19. onnxtr/models/predictor/predictor.py +22 -15
  20. onnxtr/models/recognition/models/crnn.py +39 -9
  21. onnxtr/models/recognition/models/master.py +19 -5
  22. onnxtr/models/recognition/models/parseq.py +20 -5
  23. onnxtr/models/recognition/models/sar.py +19 -5
  24. onnxtr/models/recognition/models/vitstr.py +31 -9
  25. onnxtr/models/recognition/zoo.py +12 -6
  26. onnxtr/models/zoo.py +22 -0
  27. onnxtr/py.typed +0 -0
  28. onnxtr/utils/geometry.py +33 -12
  29. onnxtr/version.py +1 -1
  30. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/METADATA +81 -16
  31. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/RECORD +35 -32
  32. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/WHEEL +1 -1
  33. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/top_level.txt +0 -1
  34. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/LICENSE +0 -0
  35. {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/zip-safe +0 -0
onnxtr/models/detection/models/differentiable_binarization.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Dict, Optional
 import numpy as np
 from scipy.special import expit
 
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
 from ..postprocessor.base import GeneralDetectionPostProcessor
 
 __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
@@ -20,18 +20,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet50-69ba0015.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_resnet50_static_8_bit-09a6104f.onnx",
     },
     "db_resnet34": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet34-b4873198.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_resnet34_static_8_bit-027e2c7f.onnx",
     },
     "db_mobilenet_v3_large": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_mobilenet_v3_large-1866973f.onnx",
+        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large-4987e7bd.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large_static_8_bit-535a6f25.onnx",
     },
 }
 
@@ -42,6 +45,7 @@ class DBNet(Engine):
     Args:
     ----
         model_path: path or url to onnx model file
+        engine_cfg: configuration for the inference engine
        bin_thresh: threshold for binarization of the output feature map
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: if True, fit straight bounding boxes only
@@ -51,14 +55,15 @@ class DBNet(Engine):
 
    def __init__(
        self,
-        model_path,
+        model_path: str,
+        engine_cfg: EngineConfig = EngineConfig(),
        bin_thresh: float = 0.3,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
        cfg: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
-        super().__init__(url=model_path, **kwargs)
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
        self.cfg = cfg
        self.assume_straight_pages = assume_straight_pages
        self.postprocessor = GeneralDetectionPostProcessor(
@@ -87,13 +92,22 @@ class DBNet(Engine):
 def _dbnet(
     arch: str,
     model_path: str,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
     **kwargs: Any,
 ) -> DBNet:
+    # Patch the url
+    model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
     # Build the model
-    return DBNet(model_path, cfg=default_cfgs[arch], **kwargs)
+    return DBNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
 
 
-def db_resnet34(model_path: str = default_cfgs["db_resnet34"]["url"], **kwargs: Any) -> DBNet:
+def db_resnet34(
+    model_path: str = default_cfgs["db_resnet34"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> DBNet:
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
     <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-34 backbone.
 
@@ -106,16 +120,23 @@ def db_resnet34(model_path: str = default_cfgs["db_resnet34"]["url"], **kwargs:
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _dbnet("db_resnet34", model_path, **kwargs)
+    return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
-def db_resnet50(model_path: str = default_cfgs["db_resnet50"]["url"], **kwargs: Any) -> DBNet:
+def db_resnet50(
+    model_path: str = default_cfgs["db_resnet50"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> DBNet:
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
     <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
 
@@ -128,16 +149,23 @@ def db_resnet50(model_path: str = default_cfgs["db_resnet50"]["url"], **kwargs:
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _dbnet("db_resnet50", model_path, **kwargs)
+    return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
-def db_mobilenet_v3_large(model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], **kwargs: Any) -> DBNet:
+def db_mobilenet_v3_large(
+    model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> DBNet:
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
     <https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
 
@@ -150,10 +178,12 @@ def db_mobilenet_v3_large(model_path: str = default_cfgs["db_mobilenet_v3_large"
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _dbnet("db_mobilenet_v3_large", model_path, **kwargs)
+    return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
onnxtr/models/detection/models/fast.py CHANGED
@@ -3,12 +3,13 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
+import logging
 from typing import Any, Dict, Optional
 
 import numpy as np
 from scipy.special import expit
 
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
 from ..postprocessor.base import GeneralDetectionPostProcessor
 
 __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"]
@@ -42,6 +43,7 @@ class FAST(Engine):
     Args:
     ----
         model_path: path or url to onnx model file
+        engine_cfg: configuration for the inference engine
        bin_thresh: threshold for binarization of the output feature map
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: if True, fit straight bounding boxes only
@@ -52,13 +54,14 @@ class FAST(Engine):
    def __init__(
        self,
        model_path: str,
+        engine_cfg: EngineConfig = EngineConfig(),
        bin_thresh: float = 0.1,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
        cfg: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
-        super().__init__(url=model_path, **kwargs)
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
        self.cfg = cfg
        self.assume_straight_pages = assume_straight_pages
 
@@ -88,13 +91,22 @@ class FAST(Engine):
 def _fast(
     arch: str,
     model_path: str,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
     **kwargs: Any,
 ) -> FAST:
+    if load_in_8_bit:
+        logging.warning("FAST models do not support 8-bit quantization yet. Loading full precision model...")
     # Build the model
-    return FAST(model_path, cfg=default_cfgs[arch], **kwargs)
+    return FAST(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
 
 
-def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], **kwargs: Any) -> FAST:
+def fast_tiny(
+    model_path: str = default_cfgs["fast_tiny"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> FAST:
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
 
@@ -107,16 +119,23 @@ def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], **kwargs: Any)
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the FAST architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _fast("fast_tiny", model_path, **kwargs)
+    return _fast("fast_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
-def fast_small(model_path: str = default_cfgs["fast_small"]["url"], **kwargs: Any) -> FAST:
+def fast_small(
+    model_path: str = default_cfgs["fast_small"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> FAST:
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
 
@@ -129,16 +148,23 @@ def fast_small(model_path: str = default_cfgs["fast_small"]["url"], **kwargs: An
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the FAST architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _fast("fast_small", model_path, **kwargs)
+    return _fast("fast_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
-def fast_base(model_path: str = default_cfgs["fast_base"]["url"], **kwargs: Any) -> FAST:
+def fast_base(
+    model_path: str = default_cfgs["fast_base"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> FAST:
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
 
@@ -151,10 +177,12 @@ def fast_base(model_path: str = default_cfgs["fast_base"]["url"], **kwargs: Any)
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the FAST architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _fast("fast_base", model_path, **kwargs)
+    return _fast("fast_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
onnxtr/models/detection/models/linknet.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Dict, Optional
 import numpy as np
 from scipy.special import expit
 
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
 from ..postprocessor.base import GeneralDetectionPostProcessor
 
 __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
@@ -20,18 +20,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet18-e0e0b9dc.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/linknet_resnet18_static_8_bit-3b3a37dd.onnx",
     },
     "linknet_resnet34": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet34-93e39a39.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/linknet_resnet34_static_8_bit-2824329d.onnx",
     },
     "linknet_resnet50": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet50-15d8c4ec.onnx",
+        "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/linknet_resnet50_static_8_bit-65d6b0b8.onnx",
     },
 }
 
@@ -42,6 +45,7 @@ class LinkNet(Engine):
     Args:
     ----
         model_path: path or url to onnx model file
+        engine_cfg: configuration for the inference engine
        bin_thresh: threshold for binarization of the output feature map
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: if True, fit straight bounding boxes only
@@ -52,13 +56,14 @@ class LinkNet(Engine):
    def __init__(
        self,
        model_path: str,
+        engine_cfg: EngineConfig = EngineConfig(),
        bin_thresh: float = 0.1,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
        cfg: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
-        super().__init__(url=model_path, **kwargs)
+        super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
        self.cfg = cfg
        self.assume_straight_pages = assume_straight_pages
 
@@ -88,13 +93,22 @@ class LinkNet(Engine):
 def _linknet(
     arch: str,
     model_path: str,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
     **kwargs: Any,
 ) -> LinkNet:
+    # Patch the url
+    model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
     # Build the model
-    return LinkNet(model_path, cfg=default_cfgs[arch], **kwargs)
+    return LinkNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
 
 
-def linknet_resnet18(model_path: str = default_cfgs["linknet_resnet18"]["url"], **kwargs: Any) -> LinkNet:
+def linknet_resnet18(
+    model_path: str = default_cfgs["linknet_resnet18"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> LinkNet:
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
     <https://arxiv.org/pdf/1707.03718.pdf>`_.
 
@@ -107,16 +121,23 @@ def linknet_resnet18(model_path: str = default_cfgs["linknet_resnet18"]["url"],
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the LinkNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _linknet("linknet_resnet18", model_path, **kwargs)
+    return _linknet("linknet_resnet18", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
-def linknet_resnet34(model_path: str = default_cfgs["linknet_resnet34"]["url"], **kwargs: Any) -> LinkNet:
+def linknet_resnet34(
+    model_path: str = default_cfgs["linknet_resnet34"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> LinkNet:
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
     <https://arxiv.org/pdf/1707.03718.pdf>`_.
 
@@ -129,16 +150,23 @@ def linknet_resnet34(model_path: str = default_cfgs["linknet_resnet34"]["url"],
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the LinkNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _linknet("linknet_resnet34", model_path, **kwargs)
+    return _linknet("linknet_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
 
 
-def linknet_resnet50(model_path: str = default_cfgs["linknet_resnet50"]["url"], **kwargs: Any) -> LinkNet:
+def linknet_resnet50(
+    model_path: str = default_cfgs["linknet_resnet50"]["url"],
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> LinkNet:
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
     <https://arxiv.org/pdf/1707.03718.pdf>`_.
 
@@ -151,10 +179,12 @@ def linknet_resnet50(model_path: str = default_cfgs["linknet_resnet50"]["url"],
     Args:
     ----
         model_path: path to onnx model file, defaults to url in default_cfgs
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the LinkNet architecture
 
     Returns:
     -------
         text detection architecture
     """
-    return _linknet("linknet_resnet50", model_path, **kwargs)
+    return _linknet("linknet_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
onnxtr/models/detection/postprocessor/base.py CHANGED
@@ -109,7 +109,7 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -136,9 +136,10 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
             # compute relative box to get rid of img shape
             _box[:, 0] /= width
             _box[:, 1] /= height
-            boxes.append(_box)
+            # Add score to box as (0, score)
+            boxes.append(np.vstack([_box, np.array([0.0, score])]))
 
         if not self.assume_straight_pages:
-            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
         else:
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
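For rotated pages, each predicted polygon therefore gains a fifth row carrying the confidence, changing the per-page array shape from (N, 4, 2) to (N, 5, 2). A minimal numpy sketch of how a consumer might split the two apart again (variable names are illustrative):

    import numpy as np

    # One fake prediction in the new format: 4 corner points plus a (0, score) row
    polys = np.array([[[0.1, 0.1], [0.4, 0.1], [0.4, 0.2], [0.1, 0.2], [0.0, 0.87]]])

    corners = polys[:, :4]         # (N, 4, 2) polygon coordinates
    scores = polys[:, 4, 1]        # (N,) objectness scores
    print(corners.shape, scores)   # (1, 4, 2) [0.87]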
onnxtr/models/detection/predictor/base.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, List, Tuple, Union
 
 import numpy as np
 
+from onnxtr.models.detection._utils import _remove_padding
 from onnxtr.models.preprocessor import PreProcessor
 from onnxtr.utils.repr import NestedObject
 
@@ -38,6 +39,11 @@ class DetectionPredictor(NestedObject):
         return_maps: bool = False,
         **kwargs: Any,
     ) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]:
+        # Extract parameters from the preprocessor
+        preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
+        symmetric_pad = self.pre_processor.resize.symmetric_pad
+        assume_straight_pages = self.model.assume_straight_pages
+
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
@@ -47,7 +53,15 @@
             self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
         ]
 
-        preds = [pred for batch in predicted_batches for pred in batch["preds"]]
+        # Remove padding from loc predictions
+        preds = _remove_padding(
+            pages,
+            [pred[0] for batch in predicted_batches for pred in batch["preds"]],
+            preserve_aspect_ratio=preserve_aspect_ratio,
+            symmetric_pad=symmetric_pad,
+            assume_straight_pages=assume_straight_pages,
+        )
+
         if return_maps:
             seg_maps = [pred for batch in predicted_batches for pred in batch["out_map"]]
             return preds, seg_maps
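`_remove_padding` lives in the new `onnxtr/models/detection/_utils/base.py` (+66 lines), whose body is not shown in this diff. Conceptually it maps box coordinates expressed relative to the padded square model input back onto the unpadded page. A rough, purely illustrative sketch under those assumptions, for straight boxes with symmetric padding only (not the library's implementation):

    import numpy as np

    def _remove_padding_sketch(page: np.ndarray, boxes: np.ndarray) -> np.ndarray:
        # Undo aspect-ratio padding for one page of (xmin, ymin, xmax, ymax, score)
        # boxes, assuming the preprocessor padded the short side symmetrically
        # to reach a square model input.
        h, w = page.shape[:2]
        if h > w:  # width was padded
            scale, coord_axes = h / w, (0, 2)
        else:      # height was padded
            scale, coord_axes = w / h, (1, 3)
        out = boxes.copy()
        for ax in coord_axes:
            # shift out the symmetric pad, then rescale back to [0, 1] of the true side
            out[:, ax] = (out[:, ax] - (1 - 1 / scale) / 2) * scale
        return np.clip(out, 0, 1)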
onnxtr/models/detection/zoo.py CHANGED
@@ -6,6 +6,7 @@
 from typing import Any
 
 from .. import detection
+from ..engine import EngineConfig
 from ..preprocessor import PreProcessor
 from .predictor import DetectionPredictor
 
@@ -24,12 +25,20 @@ ARCHS = [
 ]
 
 
-def _predictor(arch: Any, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
+def _predictor(
+    arch: Any,
+    assume_straight_pages: bool = True,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
+    **kwargs: Any,
+) -> DetectionPredictor:
     if isinstance(arch, str):
         if arch not in ARCHS:
             raise ValueError(f"unknown architecture '{arch}'")
 
-        _model = detection.__dict__[arch](assume_straight_pages=assume_straight_pages)
+        _model = detection.__dict__[arch](
+            assume_straight_pages=assume_straight_pages, load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg
+        )
     else:
         if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
             raise ValueError(f"unknown architecture: {type(arch)}")
@@ -50,6 +59,8 @@ def _predictor(arch: Any, assume_straight_pages: bool = True, **kwargs: Any) ->
 def detection_predictor(
     arch: Any = "fast_base",
     assume_straight_pages: bool = True,
+    load_in_8_bit: bool = False,
+    engine_cfg: EngineConfig = EngineConfig(),
     **kwargs: Any,
 ) -> DetectionPredictor:
     """Text detection architecture.
@@ -64,10 +75,12 @@ def detection_predictor(
     Args:
     ----
         arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
         assume_straight_pages: If True, fit straight boxes to the page
+        load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
+        engine_cfg: configuration for the inference engine
         **kwargs: optional keyword arguments passed to the architecture
 
     Returns:
     -------
         Detection predictor
     """
-    return _predictor(arch, assume_straight_pages, **kwargs)
+    return _predictor(arch, assume_straight_pages, load_in_8_bit, engine_cfg=engine_cfg, **kwargs)
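Both new options flow straight through the zoo into the factories above. A minimal sketch, assuming `detection_predictor` is re-exported at `onnxtr.models` as in upstream docTR:

    import numpy as np
    from onnxtr.models import detection_predictor

    # 8-bit quantized DBNet on the default EngineConfig (CPU unless CUDA is available)
    det_predictor = detection_predictor("db_resnet50", assume_straight_pages=True, load_in_8_bit=True)

    page = np.zeros((1024, 768, 3), dtype=np.uint8)  # one dummy 3-channel page
    loc_preds = det_predictor([page])  # per-page arrays of relative box coordinates and scores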
onnxtr/models/engine.py CHANGED
@@ -3,14 +3,79 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
-import onnxruntime
+from onnxruntime import (
+    ExecutionMode,
+    GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
+    get_available_providers,
+    get_device,
+)
 
 from onnxtr.utils.data import download_from_url
 from onnxtr.utils.geometry import shape_translate
 
+__all__ = ["EngineConfig"]
+
+
+class EngineConfig:
+    """Implements a configuration class for the engine of a model
+
+    Args:
+    ----
+        providers: list of providers to use for inference ref.: https://onnxruntime.ai/docs/execution-providers/
+        session_options: configuration for the inference session ref.: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
+    """
+
+    def __init__(
+        self,
+        providers: Optional[Union[List[Tuple[str, Dict[str, Any]]], List[str]]] = None,
+        session_options: Optional[SessionOptions] = None,
+    ):
+        self._providers = providers or self._init_providers()
+        self._session_options = session_options or self._init_sess_opts()
+
+    def _init_providers(self) -> List[Tuple[str, Dict[str, Any]]]:
+        providers: Any = [("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
+        available_providers = get_available_providers()
+        if "CUDAExecutionProvider" in available_providers and get_device() == "GPU":  # pragma: no cover
+            providers.insert(
+                0,
+                (
+                    "CUDAExecutionProvider",
+                    {
+                        "device_id": 0,
+                        "arena_extend_strategy": "kNextPowerOfTwo",
+                        "cudnn_conv_algo_search": "EXHAUSTIVE",
+                        "do_copy_in_default_stream": True,
+                    },
+                ),
+            )
+        return providers
+
+    def _init_sess_opts(self) -> SessionOptions:
+        session_options = SessionOptions()
+        session_options.enable_cpu_mem_arena = True
+        session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
+        session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+        session_options.intra_op_num_threads = -1
+        session_options.inter_op_num_threads = -1
+        return session_options
+
+    @property
+    def providers(self) -> Union[List[Tuple[str, Dict[str, Any]]], List[str]]:
+        return self._providers
+
+    @property
+    def session_options(self) -> SessionOptions:
+        return self._session_options
+
+    def __repr__(self) -> str:
+        return f"EngineConfig(providers={self.providers})"
+
 
 class Engine:
     """Implements an abstract class for the engine of a model
@@ -18,15 +83,16 @@ class Engine:
     Args:
     ----
         url: the url to use to download a model if needed
-        providers: list of providers to use for inference
+        engine_cfg: the configuration of the engine
        **kwargs: additional arguments to be passed to `download_from_url`
     """
 
-    def __init__(
-        self, url: str, providers: List[str] = ["CPUExecutionProvider", "CUDAExecutionProvider"], **kwargs: Any
-    ) -> None:
+    def __init__(self, url: str, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any) -> None:
+        engine_cfg = engine_cfg or EngineConfig()
        archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
-        self.runtime = onnxruntime.InferenceSession(archive_path, providers=providers)
+        self.session_options = engine_cfg.session_options
+        self.providers = engine_cfg.providers
+        self.runtime = InferenceSession(archive_path, providers=self.providers, sess_options=self.session_options)
        self.runtime_inputs = self.runtime.get_inputs()[0]
        self.tf_exported = int(self.runtime_inputs.shape[-1]) == 3
        self.fixed_batch_size: Union[int, str] = self.runtime_inputs.shape[
@@ -43,8 +109,8 @@ class Engine:
            inputs = np.broadcast_to(inputs, (self.fixed_batch_size, *inputs.shape))
            # combine the results
            logits = np.concatenate(
-                [self.runtime.run(self.output_name, {"input": batch})[0] for batch in inputs], axis=0
+                [self.runtime.run(self.output_name, {self.runtime_inputs.name: batch})[0] for batch in inputs], axis=0
            )
        else:
-            logits = self.runtime.run(self.output_name, {"input": inputs})[0]
+            logits = self.runtime.run(self.output_name, {self.runtime_inputs.name: inputs})[0]
        return shape_translate(logits, format="BHWC")
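To wire a custom runtime configuration through any model, construct an `EngineConfig` and hand it to a factory. A minimal sketch (the provider and session settings shown are standard onnxruntime options, not OnnxTR-specific; the thread cap is an arbitrary example value):

    from onnxruntime import SessionOptions
    from onnxtr.models.engine import EngineConfig

    opts = SessionOptions()
    opts.intra_op_num_threads = 4  # cap CPU threads instead of the -1 default above

    cfg = EngineConfig(
        providers=[("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})],
        session_options=opts,
    )
    print(cfg)  # EngineConfig(providers=[('CPUExecutionProvider', {...})])

Note that the `Engine.run` change in the last hunk also stops hardcoding the feed name: inputs are now keyed on `self.runtime_inputs.name`, so models exported with an input name other than "input" run without patching.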