onnxtr 0.1.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/io/elements.py +17 -4
- onnxtr/io/pdf.py +6 -3
- onnxtr/models/__init__.py +1 -0
- onnxtr/models/_utils.py +57 -20
- onnxtr/models/builder.py +24 -9
- onnxtr/models/classification/models/mobilenet.py +25 -7
- onnxtr/models/classification/predictor/base.py +1 -0
- onnxtr/models/classification/zoo.py +22 -7
- onnxtr/models/detection/_utils/__init__.py +1 -0
- onnxtr/models/detection/_utils/base.py +66 -0
- onnxtr/models/detection/models/differentiable_binarization.py +41 -11
- onnxtr/models/detection/models/fast.py +37 -9
- onnxtr/models/detection/models/linknet.py +39 -9
- onnxtr/models/detection/postprocessor/base.py +4 -3
- onnxtr/models/detection/predictor/base.py +15 -1
- onnxtr/models/detection/zoo.py +16 -3
- onnxtr/models/engine.py +75 -9
- onnxtr/models/predictor/base.py +69 -42
- onnxtr/models/predictor/predictor.py +22 -15
- onnxtr/models/recognition/models/crnn.py +39 -9
- onnxtr/models/recognition/models/master.py +19 -5
- onnxtr/models/recognition/models/parseq.py +20 -5
- onnxtr/models/recognition/models/sar.py +19 -5
- onnxtr/models/recognition/models/vitstr.py +31 -9
- onnxtr/models/recognition/zoo.py +12 -6
- onnxtr/models/zoo.py +22 -0
- onnxtr/py.typed +0 -0
- onnxtr/utils/geometry.py +33 -12
- onnxtr/version.py +1 -1
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/METADATA +81 -16
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/RECORD +35 -32
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/WHEEL +1 -1
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/top_level.txt +0 -1
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/LICENSE +0 -0
- {onnxtr-0.1.2.dist-info → onnxtr-0.3.0.dist-info}/zip-safe +0 -0
|
@@ -8,7 +8,7 @@ from typing import Any, Dict, Optional
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
from scipy.special import expit
|
|
10
10
|
|
|
11
|
-
from ...engine import Engine
|
|
11
|
+
from ...engine import Engine, EngineConfig
|
|
12
12
|
from ..postprocessor.base import GeneralDetectionPostProcessor
|
|
13
13
|
|
|
14
14
|
__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
|
|
@@ -20,18 +20,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
20
20
|
"mean": (0.798, 0.785, 0.772),
|
|
21
21
|
"std": (0.264, 0.2749, 0.287),
|
|
22
22
|
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet50-69ba0015.onnx",
|
|
23
|
+
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_resnet50_static_8_bit-09a6104f.onnx",
|
|
23
24
|
},
|
|
24
25
|
"db_resnet34": {
|
|
25
26
|
"input_shape": (3, 1024, 1024),
|
|
26
27
|
"mean": (0.798, 0.785, 0.772),
|
|
27
28
|
"std": (0.264, 0.2749, 0.287),
|
|
28
29
|
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet34-b4873198.onnx",
|
|
30
|
+
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_resnet34_static_8_bit-027e2c7f.onnx",
|
|
29
31
|
},
|
|
30
32
|
"db_mobilenet_v3_large": {
|
|
31
33
|
"input_shape": (3, 1024, 1024),
|
|
32
34
|
"mean": (0.798, 0.785, 0.772),
|
|
33
35
|
"std": (0.264, 0.2749, 0.287),
|
|
34
|
-
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0
|
|
36
|
+
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large-4987e7bd.onnx",
|
|
37
|
+
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large_static_8_bit-535a6f25.onnx",
|
|
35
38
|
},
|
|
36
39
|
}
|
|
37
40
|
|
|
@@ -42,6 +45,7 @@ class DBNet(Engine):
|
|
|
42
45
|
Args:
|
|
43
46
|
----
|
|
44
47
|
model_path: path or url to onnx model file
|
|
48
|
+
engine_cfg: configuration for the inference engine
|
|
45
49
|
bin_thresh: threshold for binarization of the output feature map
|
|
46
50
|
box_thresh: minimal objectness score to consider a box
|
|
47
51
|
assume_straight_pages: if True, fit straight bounding boxes only
|
|
@@ -51,14 +55,15 @@ class DBNet(Engine):
|
|
|
51
55
|
|
|
52
56
|
def __init__(
|
|
53
57
|
self,
|
|
54
|
-
model_path,
|
|
58
|
+
model_path: str,
|
|
59
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
55
60
|
bin_thresh: float = 0.3,
|
|
56
61
|
box_thresh: float = 0.1,
|
|
57
62
|
assume_straight_pages: bool = True,
|
|
58
63
|
cfg: Optional[Dict[str, Any]] = None,
|
|
59
64
|
**kwargs: Any,
|
|
60
65
|
) -> None:
|
|
61
|
-
super().__init__(url=model_path, **kwargs)
|
|
66
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
62
67
|
self.cfg = cfg
|
|
63
68
|
self.assume_straight_pages = assume_straight_pages
|
|
64
69
|
self.postprocessor = GeneralDetectionPostProcessor(
|
|
@@ -87,13 +92,22 @@ class DBNet(Engine):
|
|
|
87
92
|
def _dbnet(
|
|
88
93
|
arch: str,
|
|
89
94
|
model_path: str,
|
|
95
|
+
load_in_8_bit: bool = False,
|
|
96
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
90
97
|
**kwargs: Any,
|
|
91
98
|
) -> DBNet:
|
|
99
|
+
# Patch the url
|
|
100
|
+
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
|
|
92
101
|
# Build the model
|
|
93
|
-
return DBNet(model_path, cfg=default_cfgs[arch], **kwargs)
|
|
102
|
+
return DBNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
|
|
94
103
|
|
|
95
104
|
|
|
96
|
-
def db_resnet34(
|
|
105
|
+
def db_resnet34(
|
|
106
|
+
model_path: str = default_cfgs["db_resnet34"]["url"],
|
|
107
|
+
load_in_8_bit: bool = False,
|
|
108
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
109
|
+
**kwargs: Any,
|
|
110
|
+
) -> DBNet:
|
|
97
111
|
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
|
|
98
112
|
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-34 backbone.
|
|
99
113
|
|
|
@@ -106,16 +120,23 @@ def db_resnet34(model_path: str = default_cfgs["db_resnet34"]["url"], **kwargs:
|
|
|
106
120
|
Args:
|
|
107
121
|
----
|
|
108
122
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
123
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
124
|
+
engine_cfg: configuration for the inference engine
|
|
109
125
|
**kwargs: keyword arguments of the DBNet architecture
|
|
110
126
|
|
|
111
127
|
Returns:
|
|
112
128
|
-------
|
|
113
129
|
text detection architecture
|
|
114
130
|
"""
|
|
115
|
-
return _dbnet("db_resnet34", model_path, **kwargs)
|
|
131
|
+
return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
116
132
|
|
|
117
133
|
|
|
118
|
-
def db_resnet50(
|
|
134
|
+
def db_resnet50(
|
|
135
|
+
model_path: str = default_cfgs["db_resnet50"]["url"],
|
|
136
|
+
load_in_8_bit: bool = False,
|
|
137
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
138
|
+
**kwargs: Any,
|
|
139
|
+
) -> DBNet:
|
|
119
140
|
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
|
|
120
141
|
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
|
|
121
142
|
|
|
@@ -128,16 +149,23 @@ def db_resnet50(model_path: str = default_cfgs["db_resnet50"]["url"], **kwargs:
|
|
|
128
149
|
Args:
|
|
129
150
|
----
|
|
130
151
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
152
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
153
|
+
engine_cfg: configuration for the inference engine
|
|
131
154
|
**kwargs: keyword arguments of the DBNet architecture
|
|
132
155
|
|
|
133
156
|
Returns:
|
|
134
157
|
-------
|
|
135
158
|
text detection architecture
|
|
136
159
|
"""
|
|
137
|
-
return _dbnet("db_resnet50", model_path, **kwargs)
|
|
160
|
+
return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
138
161
|
|
|
139
162
|
|
|
140
|
-
def db_mobilenet_v3_large(
|
|
163
|
+
def db_mobilenet_v3_large(
|
|
164
|
+
model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
|
|
165
|
+
load_in_8_bit: bool = False,
|
|
166
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
167
|
+
**kwargs: Any,
|
|
168
|
+
) -> DBNet:
|
|
141
169
|
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
|
|
142
170
|
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
|
|
143
171
|
|
|
@@ -150,10 +178,12 @@ def db_mobilenet_v3_large(model_path: str = default_cfgs["db_mobilenet_v3_large"
|
|
|
150
178
|
Args:
|
|
151
179
|
----
|
|
152
180
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
181
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
182
|
+
engine_cfg: configuration for the inference engine
|
|
153
183
|
**kwargs: keyword arguments of the DBNet architecture
|
|
154
184
|
|
|
155
185
|
Returns:
|
|
156
186
|
-------
|
|
157
187
|
text detection architecture
|
|
158
188
|
"""
|
|
159
|
-
return _dbnet("db_mobilenet_v3_large", model_path, **kwargs)
|
|
189
|
+
return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -3,12 +3,13 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
import logging
|
|
6
7
|
from typing import Any, Dict, Optional
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
from scipy.special import expit
|
|
10
11
|
|
|
11
|
-
from ...engine import Engine
|
|
12
|
+
from ...engine import Engine, EngineConfig
|
|
12
13
|
from ..postprocessor.base import GeneralDetectionPostProcessor
|
|
13
14
|
|
|
14
15
|
__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"]
|
|
@@ -42,6 +43,7 @@ class FAST(Engine):
|
|
|
42
43
|
Args:
|
|
43
44
|
----
|
|
44
45
|
model_path: path or url to onnx model file
|
|
46
|
+
engine_cfg: configuration for the inference engine
|
|
45
47
|
bin_thresh: threshold for binarization of the output feature map
|
|
46
48
|
box_thresh: minimal objectness score to consider a box
|
|
47
49
|
assume_straight_pages: if True, fit straight bounding boxes only
|
|
@@ -52,13 +54,14 @@ class FAST(Engine):
|
|
|
52
54
|
def __init__(
|
|
53
55
|
self,
|
|
54
56
|
model_path: str,
|
|
57
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
55
58
|
bin_thresh: float = 0.1,
|
|
56
59
|
box_thresh: float = 0.1,
|
|
57
60
|
assume_straight_pages: bool = True,
|
|
58
61
|
cfg: Optional[Dict[str, Any]] = None,
|
|
59
62
|
**kwargs: Any,
|
|
60
63
|
) -> None:
|
|
61
|
-
super().__init__(url=model_path, **kwargs)
|
|
64
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
62
65
|
self.cfg = cfg
|
|
63
66
|
self.assume_straight_pages = assume_straight_pages
|
|
64
67
|
|
|
@@ -88,13 +91,22 @@ class FAST(Engine):
|
|
|
88
91
|
def _fast(
|
|
89
92
|
arch: str,
|
|
90
93
|
model_path: str,
|
|
94
|
+
load_in_8_bit: bool = False,
|
|
95
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
91
96
|
**kwargs: Any,
|
|
92
97
|
) -> FAST:
|
|
98
|
+
if load_in_8_bit:
|
|
99
|
+
logging.warning("FAST models do not support 8-bit quantization yet. Loading full precision model...")
|
|
93
100
|
# Build the model
|
|
94
|
-
return FAST(model_path, cfg=default_cfgs[arch], **kwargs)
|
|
101
|
+
return FAST(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
|
|
95
102
|
|
|
96
103
|
|
|
97
|
-
def fast_tiny(
|
|
104
|
+
def fast_tiny(
|
|
105
|
+
model_path: str = default_cfgs["fast_tiny"]["url"],
|
|
106
|
+
load_in_8_bit: bool = False,
|
|
107
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
108
|
+
**kwargs: Any,
|
|
109
|
+
) -> FAST:
|
|
98
110
|
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
|
|
99
111
|
<https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
|
|
100
112
|
|
|
@@ -107,16 +119,23 @@ def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], **kwargs: Any)
|
|
|
107
119
|
Args:
|
|
108
120
|
----
|
|
109
121
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
122
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
123
|
+
engine_cfg: configuration for the inference engine
|
|
110
124
|
**kwargs: keyword arguments of the DBNet architecture
|
|
111
125
|
|
|
112
126
|
Returns:
|
|
113
127
|
-------
|
|
114
128
|
text detection architecture
|
|
115
129
|
"""
|
|
116
|
-
return _fast("fast_tiny", model_path, **kwargs)
|
|
130
|
+
return _fast("fast_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
117
131
|
|
|
118
132
|
|
|
119
|
-
def fast_small(
|
|
133
|
+
def fast_small(
|
|
134
|
+
model_path: str = default_cfgs["fast_small"]["url"],
|
|
135
|
+
load_in_8_bit: bool = False,
|
|
136
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
137
|
+
**kwargs: Any,
|
|
138
|
+
) -> FAST:
|
|
120
139
|
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
|
|
121
140
|
<https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
|
|
122
141
|
|
|
@@ -129,16 +148,23 @@ def fast_small(model_path: str = default_cfgs["fast_small"]["url"], **kwargs: An
|
|
|
129
148
|
Args:
|
|
130
149
|
----
|
|
131
150
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
151
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
152
|
+
engine_cfg: configuration for the inference engine
|
|
132
153
|
**kwargs: keyword arguments of the DBNet architecture
|
|
133
154
|
|
|
134
155
|
Returns:
|
|
135
156
|
-------
|
|
136
157
|
text detection architecture
|
|
137
158
|
"""
|
|
138
|
-
return _fast("fast_small", model_path, **kwargs)
|
|
159
|
+
return _fast("fast_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
139
160
|
|
|
140
161
|
|
|
141
|
-
def fast_base(
|
|
162
|
+
def fast_base(
|
|
163
|
+
model_path: str = default_cfgs["fast_base"]["url"],
|
|
164
|
+
load_in_8_bit: bool = False,
|
|
165
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
166
|
+
**kwargs: Any,
|
|
167
|
+
) -> FAST:
|
|
142
168
|
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
|
|
143
169
|
<https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
|
|
144
170
|
|
|
@@ -151,10 +177,12 @@ def fast_base(model_path: str = default_cfgs["fast_base"]["url"], **kwargs: Any)
|
|
|
151
177
|
Args:
|
|
152
178
|
----
|
|
153
179
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
180
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
181
|
+
engine_cfg: configuration for the inference engine
|
|
154
182
|
**kwargs: keyword arguments of the DBNet architecture
|
|
155
183
|
|
|
156
184
|
Returns:
|
|
157
185
|
-------
|
|
158
186
|
text detection architecture
|
|
159
187
|
"""
|
|
160
|
-
return _fast("fast_base", model_path, **kwargs)
|
|
188
|
+
return _fast("fast_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -8,7 +8,7 @@ from typing import Any, Dict, Optional
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
from scipy.special import expit
|
|
10
10
|
|
|
11
|
-
from ...engine import Engine
|
|
11
|
+
from ...engine import Engine, EngineConfig
|
|
12
12
|
from ..postprocessor.base import GeneralDetectionPostProcessor
|
|
13
13
|
|
|
14
14
|
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
|
|
@@ -20,18 +20,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
20
20
|
"mean": (0.798, 0.785, 0.772),
|
|
21
21
|
"std": (0.264, 0.2749, 0.287),
|
|
22
22
|
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet18-e0e0b9dc.onnx",
|
|
23
|
+
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/linknet_resnet18_static_8_bit-3b3a37dd.onnx",
|
|
23
24
|
},
|
|
24
25
|
"linknet_resnet34": {
|
|
25
26
|
"input_shape": (3, 1024, 1024),
|
|
26
27
|
"mean": (0.798, 0.785, 0.772),
|
|
27
28
|
"std": (0.264, 0.2749, 0.287),
|
|
28
29
|
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet34-93e39a39.onnx",
|
|
30
|
+
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/linknet_resnet34_static_8_bit-2824329d.onnx",
|
|
29
31
|
},
|
|
30
32
|
"linknet_resnet50": {
|
|
31
33
|
"input_shape": (3, 1024, 1024),
|
|
32
34
|
"mean": (0.798, 0.785, 0.772),
|
|
33
35
|
"std": (0.264, 0.2749, 0.287),
|
|
34
36
|
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet50-15d8c4ec.onnx",
|
|
37
|
+
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/linknet_resnet50_static_8_bit-65d6b0b8.onnx",
|
|
35
38
|
},
|
|
36
39
|
}
|
|
37
40
|
|
|
@@ -42,6 +45,7 @@ class LinkNet(Engine):
|
|
|
42
45
|
Args:
|
|
43
46
|
----
|
|
44
47
|
model_path: path or url to onnx model file
|
|
48
|
+
engine_cfg: configuration for the inference engine
|
|
45
49
|
bin_thresh: threshold for binarization of the output feature map
|
|
46
50
|
box_thresh: minimal objectness score to consider a box
|
|
47
51
|
assume_straight_pages: if True, fit straight bounding boxes only
|
|
@@ -52,13 +56,14 @@ class LinkNet(Engine):
|
|
|
52
56
|
def __init__(
|
|
53
57
|
self,
|
|
54
58
|
model_path: str,
|
|
59
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
55
60
|
bin_thresh: float = 0.1,
|
|
56
61
|
box_thresh: float = 0.1,
|
|
57
62
|
assume_straight_pages: bool = True,
|
|
58
63
|
cfg: Optional[Dict[str, Any]] = None,
|
|
59
64
|
**kwargs: Any,
|
|
60
65
|
) -> None:
|
|
61
|
-
super().__init__(url=model_path, **kwargs)
|
|
66
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
62
67
|
self.cfg = cfg
|
|
63
68
|
self.assume_straight_pages = assume_straight_pages
|
|
64
69
|
|
|
@@ -88,13 +93,22 @@ class LinkNet(Engine):
|
|
|
88
93
|
def _linknet(
|
|
89
94
|
arch: str,
|
|
90
95
|
model_path: str,
|
|
96
|
+
load_in_8_bit: bool = False,
|
|
97
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
91
98
|
**kwargs: Any,
|
|
92
99
|
) -> LinkNet:
|
|
100
|
+
# Patch the url
|
|
101
|
+
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
|
|
93
102
|
# Build the model
|
|
94
|
-
return LinkNet(model_path, cfg=default_cfgs[arch], **kwargs)
|
|
103
|
+
return LinkNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
|
|
95
104
|
|
|
96
105
|
|
|
97
|
-
def linknet_resnet18(
|
|
106
|
+
def linknet_resnet18(
|
|
107
|
+
model_path: str = default_cfgs["linknet_resnet18"]["url"],
|
|
108
|
+
load_in_8_bit: bool = False,
|
|
109
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
110
|
+
**kwargs: Any,
|
|
111
|
+
) -> LinkNet:
|
|
98
112
|
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
|
|
99
113
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
100
114
|
|
|
@@ -107,16 +121,23 @@ def linknet_resnet18(model_path: str = default_cfgs["linknet_resnet18"]["url"],
|
|
|
107
121
|
Args:
|
|
108
122
|
----
|
|
109
123
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
124
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
125
|
+
engine_cfg: configuration for the inference engine
|
|
110
126
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
111
127
|
|
|
112
128
|
Returns:
|
|
113
129
|
-------
|
|
114
130
|
text detection architecture
|
|
115
131
|
"""
|
|
116
|
-
return _linknet("linknet_resnet18", model_path, **kwargs)
|
|
132
|
+
return _linknet("linknet_resnet18", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
117
133
|
|
|
118
134
|
|
|
119
|
-
def linknet_resnet34(
|
|
135
|
+
def linknet_resnet34(
|
|
136
|
+
model_path: str = default_cfgs["linknet_resnet34"]["url"],
|
|
137
|
+
load_in_8_bit: bool = False,
|
|
138
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
139
|
+
**kwargs: Any,
|
|
140
|
+
) -> LinkNet:
|
|
120
141
|
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
|
|
121
142
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
122
143
|
|
|
@@ -129,16 +150,23 @@ def linknet_resnet34(model_path: str = default_cfgs["linknet_resnet34"]["url"],
|
|
|
129
150
|
Args:
|
|
130
151
|
----
|
|
131
152
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
153
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
154
|
+
engine_cfg: configuration for the inference engine
|
|
132
155
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
133
156
|
|
|
134
157
|
Returns:
|
|
135
158
|
-------
|
|
136
159
|
text detection architecture
|
|
137
160
|
"""
|
|
138
|
-
return _linknet("linknet_resnet34", model_path, **kwargs)
|
|
161
|
+
return _linknet("linknet_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
139
162
|
|
|
140
163
|
|
|
141
|
-
def linknet_resnet50(
|
|
164
|
+
def linknet_resnet50(
|
|
165
|
+
model_path: str = default_cfgs["linknet_resnet50"]["url"],
|
|
166
|
+
load_in_8_bit: bool = False,
|
|
167
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
168
|
+
**kwargs: Any,
|
|
169
|
+
) -> LinkNet:
|
|
142
170
|
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
|
|
143
171
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
144
172
|
|
|
@@ -151,10 +179,12 @@ def linknet_resnet50(model_path: str = default_cfgs["linknet_resnet50"]["url"],
|
|
|
151
179
|
Args:
|
|
152
180
|
----
|
|
153
181
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
182
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
183
|
+
engine_cfg: configuration for the inference engine
|
|
154
184
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
155
185
|
|
|
156
186
|
Returns:
|
|
157
187
|
-------
|
|
158
188
|
text detection architecture
|
|
159
189
|
"""
|
|
160
|
-
return _linknet("linknet_resnet50", model_path, **kwargs)
|
|
190
|
+
return _linknet("linknet_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -109,7 +109,7 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
|
|
|
109
109
|
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
110
110
|
for contour in contours:
|
|
111
111
|
# Check whether smallest enclosing bounding box is not too small
|
|
112
|
-
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
|
|
112
|
+
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): # type: ignore[index]
|
|
113
113
|
continue
|
|
114
114
|
# Compute objectness
|
|
115
115
|
if self.assume_straight_pages:
|
|
@@ -136,9 +136,10 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
|
|
|
136
136
|
# compute relative box to get rid of img shape
|
|
137
137
|
_box[:, 0] /= width
|
|
138
138
|
_box[:, 1] /= height
|
|
139
|
-
|
|
139
|
+
# Add score to box as (0, score)
|
|
140
|
+
boxes.append(np.vstack([_box, np.array([0.0, score])]))
|
|
140
141
|
|
|
141
142
|
if not self.assume_straight_pages:
|
|
142
|
-
return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0,
|
|
143
|
+
return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
|
|
143
144
|
else:
|
|
144
145
|
return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
|
|
@@ -7,6 +7,7 @@ from typing import Any, List, Tuple, Union
|
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
|
+
from onnxtr.models.detection._utils import _remove_padding
|
|
10
11
|
from onnxtr.models.preprocessor import PreProcessor
|
|
11
12
|
from onnxtr.utils.repr import NestedObject
|
|
12
13
|
|
|
@@ -38,6 +39,11 @@ class DetectionPredictor(NestedObject):
|
|
|
38
39
|
return_maps: bool = False,
|
|
39
40
|
**kwargs: Any,
|
|
40
41
|
) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]:
|
|
42
|
+
# Extract parameters from the preprocessor
|
|
43
|
+
preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
|
|
44
|
+
symmetric_pad = self.pre_processor.resize.symmetric_pad
|
|
45
|
+
assume_straight_pages = self.model.assume_straight_pages
|
|
46
|
+
|
|
41
47
|
# Dimension check
|
|
42
48
|
if any(page.ndim != 3 for page in pages):
|
|
43
49
|
raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
|
|
@@ -47,7 +53,15 @@ class DetectionPredictor(NestedObject):
|
|
|
47
53
|
self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
|
|
48
54
|
]
|
|
49
55
|
|
|
50
|
-
|
|
56
|
+
# Remove padding from loc predictions
|
|
57
|
+
preds = _remove_padding(
|
|
58
|
+
pages,
|
|
59
|
+
[pred[0] for batch in predicted_batches for pred in batch["preds"]],
|
|
60
|
+
preserve_aspect_ratio=preserve_aspect_ratio,
|
|
61
|
+
symmetric_pad=symmetric_pad,
|
|
62
|
+
assume_straight_pages=assume_straight_pages,
|
|
63
|
+
)
|
|
64
|
+
|
|
51
65
|
if return_maps:
|
|
52
66
|
seg_maps = [pred for batch in predicted_batches for pred in batch["out_map"]]
|
|
53
67
|
return preds, seg_maps
|
onnxtr/models/detection/zoo.py
CHANGED
|
@@ -6,6 +6,7 @@
|
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
8
|
from .. import detection
|
|
9
|
+
from ..engine import EngineConfig
|
|
9
10
|
from ..preprocessor import PreProcessor
|
|
10
11
|
from .predictor import DetectionPredictor
|
|
11
12
|
|
|
@@ -24,12 +25,20 @@ ARCHS = [
|
|
|
24
25
|
]
|
|
25
26
|
|
|
26
27
|
|
|
27
|
-
def _predictor(
|
|
28
|
+
def _predictor(
|
|
29
|
+
arch: Any,
|
|
30
|
+
assume_straight_pages: bool = True,
|
|
31
|
+
load_in_8_bit: bool = False,
|
|
32
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
33
|
+
**kwargs: Any,
|
|
34
|
+
) -> DetectionPredictor:
|
|
28
35
|
if isinstance(arch, str):
|
|
29
36
|
if arch not in ARCHS:
|
|
30
37
|
raise ValueError(f"unknown architecture '{arch}'")
|
|
31
38
|
|
|
32
|
-
_model = detection.__dict__[arch](
|
|
39
|
+
_model = detection.__dict__[arch](
|
|
40
|
+
assume_straight_pages=assume_straight_pages, load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg
|
|
41
|
+
)
|
|
33
42
|
else:
|
|
34
43
|
if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
|
|
35
44
|
raise ValueError(f"unknown architecture: {type(arch)}")
|
|
@@ -50,6 +59,8 @@ def _predictor(arch: Any, assume_straight_pages: bool = True, **kwargs: Any) ->
|
|
|
50
59
|
def detection_predictor(
|
|
51
60
|
arch: Any = "fast_base",
|
|
52
61
|
assume_straight_pages: bool = True,
|
|
62
|
+
load_in_8_bit: bool = False,
|
|
63
|
+
engine_cfg: EngineConfig = EngineConfig(),
|
|
53
64
|
**kwargs: Any,
|
|
54
65
|
) -> DetectionPredictor:
|
|
55
66
|
"""Text detection architecture.
|
|
@@ -64,10 +75,12 @@ def detection_predictor(
|
|
|
64
75
|
----
|
|
65
76
|
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
|
|
66
77
|
assume_straight_pages: If True, fit straight boxes to the page
|
|
78
|
+
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
|
|
79
|
+
engine_cfg: configuration for the inference engine
|
|
67
80
|
**kwargs: optional keyword arguments passed to the architecture
|
|
68
81
|
|
|
69
82
|
Returns:
|
|
70
83
|
-------
|
|
71
84
|
Detection predictor
|
|
72
85
|
"""
|
|
73
|
-
return _predictor(arch, assume_straight_pages, **kwargs)
|
|
86
|
+
return _predictor(arch, assume_straight_pages, load_in_8_bit, engine_cfg=engine_cfg, **kwargs)
|
onnxtr/models/engine.py
CHANGED
|
@@ -3,14 +3,79 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any, List, Union
|
|
6
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
|
-
import
|
|
9
|
+
from onnxruntime import (
|
|
10
|
+
ExecutionMode,
|
|
11
|
+
GraphOptimizationLevel,
|
|
12
|
+
InferenceSession,
|
|
13
|
+
SessionOptions,
|
|
14
|
+
get_available_providers,
|
|
15
|
+
get_device,
|
|
16
|
+
)
|
|
10
17
|
|
|
11
18
|
from onnxtr.utils.data import download_from_url
|
|
12
19
|
from onnxtr.utils.geometry import shape_translate
|
|
13
20
|
|
|
21
|
+
__all__ = ["EngineConfig"]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class EngineConfig:
|
|
25
|
+
"""Implements a configuration class for the engine of a model
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
----
|
|
29
|
+
providers: list of providers to use for inference ref.: https://onnxruntime.ai/docs/execution-providers/
|
|
30
|
+
session_options: configuration for the inference session ref.: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
providers: Optional[Union[List[Tuple[str, Dict[str, Any]]], List[str]]] = None,
|
|
36
|
+
session_options: Optional[SessionOptions] = None,
|
|
37
|
+
):
|
|
38
|
+
self._providers = providers or self._init_providers()
|
|
39
|
+
self._session_options = session_options or self._init_sess_opts()
|
|
40
|
+
|
|
41
|
+
def _init_providers(self) -> List[Tuple[str, Dict[str, Any]]]:
|
|
42
|
+
providers: Any = [("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
|
|
43
|
+
available_providers = get_available_providers()
|
|
44
|
+
if "CUDAExecutionProvider" in available_providers and get_device() == "GPU": # pragma: no cover
|
|
45
|
+
providers.insert(
|
|
46
|
+
0,
|
|
47
|
+
(
|
|
48
|
+
"CUDAExecutionProvider",
|
|
49
|
+
{
|
|
50
|
+
"device_id": 0,
|
|
51
|
+
"arena_extend_strategy": "kNextPowerOfTwo",
|
|
52
|
+
"cudnn_conv_algo_search": "EXHAUSTIVE",
|
|
53
|
+
"do_copy_in_default_stream": True,
|
|
54
|
+
},
|
|
55
|
+
),
|
|
56
|
+
)
|
|
57
|
+
return providers
|
|
58
|
+
|
|
59
|
+
def _init_sess_opts(self) -> SessionOptions:
|
|
60
|
+
session_options = SessionOptions()
|
|
61
|
+
session_options.enable_cpu_mem_arena = True
|
|
62
|
+
session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
|
|
63
|
+
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
64
|
+
session_options.intra_op_num_threads = -1
|
|
65
|
+
session_options.inter_op_num_threads = -1
|
|
66
|
+
return session_options
|
|
67
|
+
|
|
68
|
+
@property
|
|
69
|
+
def providers(self) -> Union[List[Tuple[str, Dict[str, Any]]], List[str]]:
|
|
70
|
+
return self._providers
|
|
71
|
+
|
|
72
|
+
@property
|
|
73
|
+
def session_options(self) -> SessionOptions:
|
|
74
|
+
return self._session_options
|
|
75
|
+
|
|
76
|
+
def __repr__(self) -> str:
|
|
77
|
+
return f"EngineConfig(providers={self.providers}"
|
|
78
|
+
|
|
14
79
|
|
|
15
80
|
class Engine:
|
|
16
81
|
"""Implements an abstract class for the engine of a model
|
|
@@ -18,15 +83,16 @@ class Engine:
|
|
|
18
83
|
Args:
|
|
19
84
|
----
|
|
20
85
|
url: the url to use to download a model if needed
|
|
21
|
-
|
|
86
|
+
engine_cfg: the configuration of the engine
|
|
22
87
|
**kwargs: additional arguments to be passed to `download_from_url`
|
|
23
88
|
"""
|
|
24
89
|
|
|
25
|
-
def __init__(
|
|
26
|
-
|
|
27
|
-
) -> None:
|
|
90
|
+
def __init__(self, url: str, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any) -> None:
|
|
91
|
+
engine_cfg = engine_cfg or EngineConfig()
|
|
28
92
|
archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
|
|
29
|
-
self.
|
|
93
|
+
self.session_options = engine_cfg.session_options
|
|
94
|
+
self.providers = engine_cfg.providers
|
|
95
|
+
self.runtime = InferenceSession(archive_path, providers=self.providers, sess_options=self.session_options)
|
|
30
96
|
self.runtime_inputs = self.runtime.get_inputs()[0]
|
|
31
97
|
self.tf_exported = int(self.runtime_inputs.shape[-1]) == 3
|
|
32
98
|
self.fixed_batch_size: Union[int, str] = self.runtime_inputs.shape[
|
|
@@ -43,8 +109,8 @@ class Engine:
|
|
|
43
109
|
inputs = np.broadcast_to(inputs, (self.fixed_batch_size, *inputs.shape))
|
|
44
110
|
# combine the results
|
|
45
111
|
logits = np.concatenate(
|
|
46
|
-
[self.runtime.run(self.output_name, {
|
|
112
|
+
[self.runtime.run(self.output_name, {self.runtime_inputs.name: batch})[0] for batch in inputs], axis=0
|
|
47
113
|
)
|
|
48
114
|
else:
|
|
49
|
-
logits = self.runtime.run(self.output_name, {
|
|
115
|
+
logits = self.runtime.run(self.output_name, {self.runtime_inputs.name: inputs})[0]
|
|
50
116
|
return shape_translate(logits, format="BHWC")
|