onnxtr 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/contrib/base.py +1 -4
- onnxtr/io/elements.py +17 -4
- onnxtr/io/pdf.py +6 -3
- onnxtr/models/__init__.py +1 -0
- onnxtr/models/_utils.py +57 -20
- onnxtr/models/builder.py +24 -9
- onnxtr/models/classification/models/mobilenet.py +12 -5
- onnxtr/models/classification/zoo.py +20 -8
- onnxtr/models/detection/_utils/__init__.py +1 -0
- onnxtr/models/detection/_utils/base.py +66 -0
- onnxtr/models/detection/models/differentiable_binarization.py +27 -12
- onnxtr/models/detection/models/fast.py +30 -9
- onnxtr/models/detection/models/linknet.py +24 -9
- onnxtr/models/detection/postprocessor/base.py +4 -3
- onnxtr/models/detection/predictor/base.py +15 -1
- onnxtr/models/detection/zoo.py +14 -5
- onnxtr/models/engine.py +73 -7
- onnxtr/models/predictor/base.py +65 -42
- onnxtr/models/predictor/predictor.py +23 -16
- onnxtr/models/recognition/models/crnn.py +24 -9
- onnxtr/models/recognition/models/master.py +14 -5
- onnxtr/models/recognition/models/parseq.py +14 -5
- onnxtr/models/recognition/models/sar.py +12 -5
- onnxtr/models/recognition/models/vitstr.py +18 -7
- onnxtr/models/recognition/zoo.py +10 -7
- onnxtr/models/zoo.py +19 -3
- onnxtr/py.typed +0 -0
- onnxtr/utils/geometry.py +33 -12
- onnxtr/version.py +1 -1
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/METADATA +63 -24
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/RECORD +35 -32
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/WHEEL +1 -1
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/top_level.txt +0 -1
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/LICENSE +0 -0
- {onnxtr-0.2.0.dist-info → onnxtr-0.3.1.dist-info}/zip-safe +0 -0
|
@@ -8,7 +8,7 @@ from typing import Any, Dict, Optional
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
from scipy.special import expit
|
|
10
10
|
|
|
11
|
-
from ...engine import Engine
|
|
11
|
+
from ...engine import Engine, EngineConfig
|
|
12
12
|
from ..postprocessor.base import GeneralDetectionPostProcessor
|
|
13
13
|
|
|
14
14
|
__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
|
|
@@ -33,8 +33,8 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
33
33
|
"input_shape": (3, 1024, 1024),
|
|
34
34
|
"mean": (0.798, 0.785, 0.772),
|
|
35
35
|
"std": (0.264, 0.2749, 0.287),
|
|
36
|
-
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0
|
|
37
|
-
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.
|
|
36
|
+
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large-4987e7bd.onnx",
|
|
37
|
+
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large_static_8_bit-535a6f25.onnx",
|
|
38
38
|
},
|
|
39
39
|
}
|
|
40
40
|
|
|
@@ -45,6 +45,7 @@ class DBNet(Engine):
|
|
|
45
45
|
Args:
|
|
46
46
|
----
|
|
47
47
|
model_path: path or url to onnx model file
|
|
48
|
+
engine_cfg: configuration for the inference engine
|
|
48
49
|
bin_thresh: threshold for binarization of the output feature map
|
|
49
50
|
box_thresh: minimal objectness score to consider a box
|
|
50
51
|
assume_straight_pages: if True, fit straight bounding boxes only
|
|
@@ -54,14 +55,15 @@ class DBNet(Engine):
|
|
|
54
55
|
|
|
55
56
|
def __init__(
|
|
56
57
|
self,
|
|
57
|
-
model_path,
|
|
58
|
+
model_path: str,
|
|
59
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
58
60
|
bin_thresh: float = 0.3,
|
|
59
61
|
box_thresh: float = 0.1,
|
|
60
62
|
assume_straight_pages: bool = True,
|
|
61
63
|
cfg: Optional[Dict[str, Any]] = None,
|
|
62
64
|
**kwargs: Any,
|
|
63
65
|
) -> None:
|
|
64
|
-
super().__init__(url=model_path, **kwargs)
|
|
66
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
65
67
|
self.cfg = cfg
|
|
66
68
|
self.assume_straight_pages = assume_straight_pages
|
|
67
69
|
self.postprocessor = GeneralDetectionPostProcessor(
|
|
@@ -91,16 +93,20 @@ def _dbnet(
|
|
|
91
93
|
arch: str,
|
|
92
94
|
model_path: str,
|
|
93
95
|
load_in_8_bit: bool = False,
|
|
96
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
94
97
|
**kwargs: Any,
|
|
95
98
|
) -> DBNet:
|
|
96
99
|
# Patch the url
|
|
97
100
|
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
|
|
98
101
|
# Build the model
|
|
99
|
-
return DBNet(model_path, cfg=default_cfgs[arch], **kwargs)
|
|
102
|
+
return DBNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
|
|
100
103
|
|
|
101
104
|
|
|
102
105
|
def db_resnet34(
|
|
103
|
-
model_path: str = default_cfgs["db_resnet34"]["url"],
|
|
106
|
+
model_path: str = default_cfgs["db_resnet34"]["url"],
|
|
107
|
+
load_in_8_bit: bool = False,
|
|
108
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
109
|
+
**kwargs: Any,
|
|
104
110
|
) -> DBNet:
|
|
105
111
|
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
|
|
106
112
|
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-34 backbone.
|
|
@@ -115,17 +121,21 @@ def db_resnet34(
|
|
|
115
121
|
----
|
|
116
122
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
117
123
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
124
|
+
engine_cfg: configuration for the inference engine
|
|
118
125
|
**kwargs: keyword arguments of the DBNet architecture
|
|
119
126
|
|
|
120
127
|
Returns:
|
|
121
128
|
-------
|
|
122
129
|
text detection architecture
|
|
123
130
|
"""
|
|
124
|
-
return _dbnet("db_resnet34", model_path, load_in_8_bit, **kwargs)
|
|
131
|
+
return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
125
132
|
|
|
126
133
|
|
|
127
134
|
def db_resnet50(
|
|
128
|
-
model_path: str = default_cfgs["db_resnet50"]["url"],
|
|
135
|
+
model_path: str = default_cfgs["db_resnet50"]["url"],
|
|
136
|
+
load_in_8_bit: bool = False,
|
|
137
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
138
|
+
**kwargs: Any,
|
|
129
139
|
) -> DBNet:
|
|
130
140
|
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
|
|
131
141
|
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
|
|
@@ -140,17 +150,21 @@ def db_resnet50(
|
|
|
140
150
|
----
|
|
141
151
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
142
152
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
153
|
+
engine_cfg: configuration for the inference engine
|
|
143
154
|
**kwargs: keyword arguments of the DBNet architecture
|
|
144
155
|
|
|
145
156
|
Returns:
|
|
146
157
|
-------
|
|
147
158
|
text detection architecture
|
|
148
159
|
"""
|
|
149
|
-
return _dbnet("db_resnet50", model_path, load_in_8_bit, **kwargs)
|
|
160
|
+
return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
150
161
|
|
|
151
162
|
|
|
152
163
|
def db_mobilenet_v3_large(
|
|
153
|
-
model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
|
|
164
|
+
model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
|
|
165
|
+
load_in_8_bit: bool = False,
|
|
166
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
167
|
+
**kwargs: Any,
|
|
154
168
|
) -> DBNet:
|
|
155
169
|
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
|
|
156
170
|
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
|
|
@@ -165,10 +179,11 @@ def db_mobilenet_v3_large(
|
|
|
165
179
|
----
|
|
166
180
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
167
181
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
182
|
+
engine_cfg: configuration for the inference engine
|
|
168
183
|
**kwargs: keyword arguments of the DBNet architecture
|
|
169
184
|
|
|
170
185
|
Returns:
|
|
171
186
|
-------
|
|
172
187
|
text detection architecture
|
|
173
188
|
"""
|
|
174
|
-
return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs)
|
|
189
|
+
return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -9,7 +9,7 @@ from typing import Any, Dict, Optional
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from scipy.special import expit
|
|
11
11
|
|
|
12
|
-
from ...engine import Engine
|
|
12
|
+
from ...engine import Engine, EngineConfig
|
|
13
13
|
from ..postprocessor.base import GeneralDetectionPostProcessor
|
|
14
14
|
|
|
15
15
|
__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"]
|
|
@@ -43,6 +43,7 @@ class FAST(Engine):
|
|
|
43
43
|
Args:
|
|
44
44
|
----
|
|
45
45
|
model_path: path or url to onnx model file
|
|
46
|
+
engine_cfg: configuration for the inference engine
|
|
46
47
|
bin_thresh: threshold for binarization of the output feature map
|
|
47
48
|
box_thresh: minimal objectness score to consider a box
|
|
48
49
|
assume_straight_pages: if True, fit straight bounding boxes only
|
|
@@ -53,13 +54,14 @@ class FAST(Engine):
|
|
|
53
54
|
def __init__(
|
|
54
55
|
self,
|
|
55
56
|
model_path: str,
|
|
57
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
56
58
|
bin_thresh: float = 0.1,
|
|
57
59
|
box_thresh: float = 0.1,
|
|
58
60
|
assume_straight_pages: bool = True,
|
|
59
61
|
cfg: Optional[Dict[str, Any]] = None,
|
|
60
62
|
**kwargs: Any,
|
|
61
63
|
) -> None:
|
|
62
|
-
super().__init__(url=model_path, **kwargs)
|
|
64
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
63
65
|
self.cfg = cfg
|
|
64
66
|
self.assume_straight_pages = assume_straight_pages
|
|
65
67
|
|
|
@@ -90,15 +92,21 @@ def _fast(
|
|
|
90
92
|
arch: str,
|
|
91
93
|
model_path: str,
|
|
92
94
|
load_in_8_bit: bool = False,
|
|
95
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
93
96
|
**kwargs: Any,
|
|
94
97
|
) -> FAST:
|
|
95
98
|
if load_in_8_bit:
|
|
96
99
|
logging.warning("FAST models do not support 8-bit quantization yet. Loading full precision model...")
|
|
97
100
|
# Build the model
|
|
98
|
-
return FAST(model_path, cfg=default_cfgs[arch], **kwargs)
|
|
101
|
+
return FAST(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
|
|
99
102
|
|
|
100
103
|
|
|
101
|
-
def fast_tiny(
|
|
104
|
+
def fast_tiny(
|
|
105
|
+
model_path: str = default_cfgs["fast_tiny"]["url"],
|
|
106
|
+
load_in_8_bit: bool = False,
|
|
107
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
108
|
+
**kwargs: Any,
|
|
109
|
+
) -> FAST:
|
|
102
110
|
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
|
|
103
111
|
<https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
|
|
104
112
|
|
|
@@ -112,16 +120,22 @@ def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], load_in_8_bit:
|
|
|
112
120
|
----
|
|
113
121
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
114
122
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
123
|
+
engine_cfg: configuration for the inference engine
|
|
115
124
|
**kwargs: keyword arguments of the FAST architecture
|
|
116
125
|
|
|
117
126
|
Returns:
|
|
118
127
|
-------
|
|
119
128
|
text detection architecture
|
|
120
129
|
"""
|
|
121
|
-
return _fast("fast_tiny", model_path, load_in_8_bit, **kwargs)
|
|
130
|
+
return _fast("fast_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
122
131
|
|
|
123
132
|
|
|
124
|
-
def fast_small(
|
|
133
|
+
def fast_small(
|
|
134
|
+
model_path: str = default_cfgs["fast_small"]["url"],
|
|
135
|
+
load_in_8_bit: bool = False,
|
|
136
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
137
|
+
**kwargs: Any,
|
|
138
|
+
) -> FAST:
|
|
125
139
|
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
|
|
126
140
|
<https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
|
|
127
141
|
|
|
@@ -135,16 +149,22 @@ def fast_small(model_path: str = default_cfgs["fast_small"]["url"], load_in_8_bi
|
|
|
135
149
|
----
|
|
136
150
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
137
151
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
152
|
+
engine_cfg: configuration for the inference engine
|
|
138
153
|
**kwargs: keyword arguments of the FAST architecture
|
|
139
154
|
|
|
140
155
|
Returns:
|
|
141
156
|
-------
|
|
142
157
|
text detection architecture
|
|
143
158
|
"""
|
|
144
|
-
return _fast("fast_small", model_path, load_in_8_bit, **kwargs)
|
|
159
|
+
return _fast("fast_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
145
160
|
|
|
146
161
|
|
|
147
|
-
def fast_base(
|
|
162
|
+
def fast_base(
|
|
163
|
+
model_path: str = default_cfgs["fast_base"]["url"],
|
|
164
|
+
load_in_8_bit: bool = False,
|
|
165
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
166
|
+
**kwargs: Any,
|
|
167
|
+
) -> FAST:
|
|
148
168
|
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
|
|
149
169
|
<https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
|
|
150
170
|
|
|
@@ -158,10 +178,11 @@ def fast_base(model_path: str = default_cfgs["fast_base"]["url"], load_in_8_bit:
|
|
|
158
178
|
----
|
|
159
179
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
160
180
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
181
|
+
engine_cfg: configuration for the inference engine
|
|
161
182
|
**kwargs: keyword arguments of the FAST architecture
|
|
162
183
|
|
|
163
184
|
Returns:
|
|
164
185
|
-------
|
|
165
186
|
text detection architecture
|
|
166
187
|
"""
|
|
167
|
-
return _fast("fast_base", model_path, load_in_8_bit, **kwargs)
|
|
188
|
+
return _fast("fast_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -8,7 +8,7 @@ from typing import Any, Dict, Optional
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
from scipy.special import expit
|
|
10
10
|
|
|
11
|
-
from ...engine import Engine
|
|
11
|
+
from ...engine import Engine, EngineConfig
|
|
12
12
|
from ..postprocessor.base import GeneralDetectionPostProcessor
|
|
13
13
|
|
|
14
14
|
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
|
|
@@ -45,6 +45,7 @@ class LinkNet(Engine):
|
|
|
45
45
|
Args:
|
|
46
46
|
----
|
|
47
47
|
model_path: path or url to onnx model file
|
|
48
|
+
engine_cfg: configuration for the inference engine
|
|
48
49
|
bin_thresh: threshold for binarization of the output feature map
|
|
49
50
|
box_thresh: minimal objectness score to consider a box
|
|
50
51
|
assume_straight_pages: if True, fit straight bounding boxes only
|
|
@@ -55,13 +56,14 @@ class LinkNet(Engine):
|
|
|
55
56
|
def __init__(
|
|
56
57
|
self,
|
|
57
58
|
model_path: str,
|
|
59
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
58
60
|
bin_thresh: float = 0.1,
|
|
59
61
|
box_thresh: float = 0.1,
|
|
60
62
|
assume_straight_pages: bool = True,
|
|
61
63
|
cfg: Optional[Dict[str, Any]] = None,
|
|
62
64
|
**kwargs: Any,
|
|
63
65
|
) -> None:
|
|
64
|
-
super().__init__(url=model_path, **kwargs)
|
|
66
|
+
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
65
67
|
self.cfg = cfg
|
|
66
68
|
self.assume_straight_pages = assume_straight_pages
|
|
67
69
|
|
|
@@ -92,16 +94,20 @@ def _linknet(
|
|
|
92
94
|
arch: str,
|
|
93
95
|
model_path: str,
|
|
94
96
|
load_in_8_bit: bool = False,
|
|
97
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
95
98
|
**kwargs: Any,
|
|
96
99
|
) -> LinkNet:
|
|
97
100
|
# Patch the url
|
|
98
101
|
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
|
|
99
102
|
# Build the model
|
|
100
|
-
return LinkNet(model_path, cfg=default_cfgs[arch], **kwargs)
|
|
103
|
+
return LinkNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
|
|
101
104
|
|
|
102
105
|
|
|
103
106
|
def linknet_resnet18(
|
|
104
|
-
model_path: str = default_cfgs["linknet_resnet18"]["url"],
|
|
107
|
+
model_path: str = default_cfgs["linknet_resnet18"]["url"],
|
|
108
|
+
load_in_8_bit: bool = False,
|
|
109
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
110
|
+
**kwargs: Any,
|
|
105
111
|
) -> LinkNet:
|
|
106
112
|
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
|
|
107
113
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
@@ -116,17 +122,21 @@ def linknet_resnet18(
|
|
|
116
122
|
----
|
|
117
123
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
118
124
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
125
|
+
engine_cfg: configuration for the inference engine
|
|
119
126
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
120
127
|
|
|
121
128
|
Returns:
|
|
122
129
|
-------
|
|
123
130
|
text detection architecture
|
|
124
131
|
"""
|
|
125
|
-
return _linknet("linknet_resnet18", model_path, load_in_8_bit, **kwargs)
|
|
132
|
+
return _linknet("linknet_resnet18", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
126
133
|
|
|
127
134
|
|
|
128
135
|
def linknet_resnet34(
|
|
129
|
-
model_path: str = default_cfgs["linknet_resnet34"]["url"],
|
|
136
|
+
model_path: str = default_cfgs["linknet_resnet34"]["url"],
|
|
137
|
+
load_in_8_bit: bool = False,
|
|
138
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
139
|
+
**kwargs: Any,
|
|
130
140
|
) -> LinkNet:
|
|
131
141
|
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
|
|
132
142
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
@@ -141,17 +151,21 @@ def linknet_resnet34(
|
|
|
141
151
|
----
|
|
142
152
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
143
153
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
154
|
+
engine_cfg: configuration for the inference engine
|
|
144
155
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
145
156
|
|
|
146
157
|
Returns:
|
|
147
158
|
-------
|
|
148
159
|
text detection architecture
|
|
149
160
|
"""
|
|
150
|
-
return _linknet("linknet_resnet34", model_path, load_in_8_bit, **kwargs)
|
|
161
|
+
return _linknet("linknet_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
151
162
|
|
|
152
163
|
|
|
153
164
|
def linknet_resnet50(
|
|
154
|
-
model_path: str = default_cfgs["linknet_resnet50"]["url"],
|
|
165
|
+
model_path: str = default_cfgs["linknet_resnet50"]["url"],
|
|
166
|
+
load_in_8_bit: bool = False,
|
|
167
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
168
|
+
**kwargs: Any,
|
|
155
169
|
) -> LinkNet:
|
|
156
170
|
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
|
|
157
171
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
@@ -166,10 +180,11 @@ def linknet_resnet50(
|
|
|
166
180
|
----
|
|
167
181
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
168
182
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
183
|
+
engine_cfg: configuration for the inference engine
|
|
169
184
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
170
185
|
|
|
171
186
|
Returns:
|
|
172
187
|
-------
|
|
173
188
|
text detection architecture
|
|
174
189
|
"""
|
|
175
|
-
return _linknet("linknet_resnet50", model_path, load_in_8_bit, **kwargs)
|
|
190
|
+
return _linknet("linknet_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -109,7 +109,7 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
|
|
|
109
109
|
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
110
110
|
for contour in contours:
|
|
111
111
|
# Check whether smallest enclosing bounding box is not too small
|
|
112
|
-
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
|
|
112
|
+
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): # type: ignore[index]
|
|
113
113
|
continue
|
|
114
114
|
# Compute objectness
|
|
115
115
|
if self.assume_straight_pages:
|
|
@@ -136,9 +136,10 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
|
|
|
136
136
|
# compute relative box to get rid of img shape
|
|
137
137
|
_box[:, 0] /= width
|
|
138
138
|
_box[:, 1] /= height
|
|
139
|
-
|
|
139
|
+
# Add score to box as (0, score)
|
|
140
|
+
boxes.append(np.vstack([_box, np.array([0.0, score])]))
|
|
140
141
|
|
|
141
142
|
if not self.assume_straight_pages:
|
|
142
|
-
return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0,
|
|
143
|
+
return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
|
|
143
144
|
else:
|
|
144
145
|
return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
|
|
@@ -7,6 +7,7 @@ from typing import Any, List, Tuple, Union
|
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
|
+
from onnxtr.models.detection._utils import _remove_padding
|
|
10
11
|
from onnxtr.models.preprocessor import PreProcessor
|
|
11
12
|
from onnxtr.utils.repr import NestedObject
|
|
12
13
|
|
|
@@ -38,6 +39,11 @@ class DetectionPredictor(NestedObject):
|
|
|
38
39
|
return_maps: bool = False,
|
|
39
40
|
**kwargs: Any,
|
|
40
41
|
) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]:
|
|
42
|
+
# Extract parameters from the preprocessor
|
|
43
|
+
preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
|
|
44
|
+
symmetric_pad = self.pre_processor.resize.symmetric_pad
|
|
45
|
+
assume_straight_pages = self.model.assume_straight_pages
|
|
46
|
+
|
|
41
47
|
# Dimension check
|
|
42
48
|
if any(page.ndim != 3 for page in pages):
|
|
43
49
|
raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
|
|
@@ -47,7 +53,15 @@ class DetectionPredictor(NestedObject):
|
|
|
47
53
|
self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
|
|
48
54
|
]
|
|
49
55
|
|
|
50
|
-
|
|
56
|
+
# Remove padding from loc predictions
|
|
57
|
+
preds = _remove_padding(
|
|
58
|
+
pages,
|
|
59
|
+
[pred[0] for batch in predicted_batches for pred in batch["preds"]],
|
|
60
|
+
preserve_aspect_ratio=preserve_aspect_ratio,
|
|
61
|
+
symmetric_pad=symmetric_pad,
|
|
62
|
+
assume_straight_pages=assume_straight_pages,
|
|
63
|
+
)
|
|
64
|
+
|
|
51
65
|
if return_maps:
|
|
52
66
|
seg_maps = [pred for batch in predicted_batches for pred in batch["out_map"]]
|
|
53
67
|
return preds, seg_maps
|
onnxtr/models/detection/zoo.py
CHANGED
|
@@ -3,9 +3,10 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any, Optional
|
|
7
7
|
|
|
8
8
|
from .. import detection
|
|
9
|
+
from ..engine import EngineConfig
|
|
9
10
|
from ..preprocessor import PreProcessor
|
|
10
11
|
from .predictor import DetectionPredictor
|
|
11
12
|
|
|
@@ -25,13 +26,19 @@ ARCHS = [
|
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def _predictor(
|
|
28
|
-
arch: Any,
|
|
29
|
+
arch: Any,
|
|
30
|
+
assume_straight_pages: bool = True,
|
|
31
|
+
load_in_8_bit: bool = False,
|
|
32
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
33
|
+
**kwargs: Any,
|
|
29
34
|
) -> DetectionPredictor:
|
|
30
35
|
if isinstance(arch, str):
|
|
31
36
|
if arch not in ARCHS:
|
|
32
37
|
raise ValueError(f"unknown architecture '{arch}'")
|
|
33
38
|
|
|
34
|
-
_model = detection.__dict__[arch](
|
|
39
|
+
_model = detection.__dict__[arch](
|
|
40
|
+
assume_straight_pages=assume_straight_pages, load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg
|
|
41
|
+
)
|
|
35
42
|
else:
|
|
36
43
|
if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
|
|
37
44
|
raise ValueError(f"unknown architecture: {type(arch)}")
|
|
@@ -41,7 +48,7 @@ def _predictor(
|
|
|
41
48
|
|
|
42
49
|
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
|
|
43
50
|
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
|
|
44
|
-
kwargs["batch_size"] = kwargs.get("batch_size",
|
|
51
|
+
kwargs["batch_size"] = kwargs.get("batch_size", 2)
|
|
45
52
|
predictor = DetectionPredictor(
|
|
46
53
|
PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
|
|
47
54
|
_model,
|
|
@@ -53,6 +60,7 @@ def detection_predictor(
|
|
|
53
60
|
arch: Any = "fast_base",
|
|
54
61
|
assume_straight_pages: bool = True,
|
|
55
62
|
load_in_8_bit: bool = False,
|
|
63
|
+
engine_cfg: Optional[EngineConfig] = None,
|
|
56
64
|
**kwargs: Any,
|
|
57
65
|
) -> DetectionPredictor:
|
|
58
66
|
"""Text detection architecture.
|
|
@@ -68,10 +76,11 @@ def detection_predictor(
|
|
|
68
76
|
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
|
|
69
77
|
assume_straight_pages: If True, fit straight boxes to the page
|
|
70
78
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
79
|
+
engine_cfg: configuration for the inference engine
|
|
71
80
|
**kwargs: optional keyword arguments passed to the architecture
|
|
72
81
|
|
|
73
82
|
Returns:
|
|
74
83
|
-------
|
|
75
84
|
Detection predictor
|
|
76
85
|
"""
|
|
77
|
-
return _predictor(arch, assume_straight_pages, load_in_8_bit, **kwargs)
|
|
86
|
+
return _predictor(arch, assume_straight_pages, load_in_8_bit, engine_cfg=engine_cfg, **kwargs)
|
onnxtr/models/engine.py
CHANGED
|
@@ -3,14 +3,79 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any, List, Union
|
|
6
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
|
-
import
|
|
9
|
+
from onnxruntime import (
|
|
10
|
+
ExecutionMode,
|
|
11
|
+
GraphOptimizationLevel,
|
|
12
|
+
InferenceSession,
|
|
13
|
+
SessionOptions,
|
|
14
|
+
get_available_providers,
|
|
15
|
+
get_device,
|
|
16
|
+
)
|
|
10
17
|
|
|
11
18
|
from onnxtr.utils.data import download_from_url
|
|
12
19
|
from onnxtr.utils.geometry import shape_translate
|
|
13
20
|
|
|
21
|
+
__all__ = ["EngineConfig"]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class EngineConfig:
|
|
25
|
+
"""Implements a configuration class for the engine of a model
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
----
|
|
29
|
+
providers: list of providers to use for inference ref.: https://onnxruntime.ai/docs/execution-providers/
|
|
30
|
+
session_options: configuration for the inference session ref.: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
providers: Optional[Union[List[Tuple[str, Dict[str, Any]]], List[str]]] = None,
|
|
36
|
+
session_options: Optional[SessionOptions] = None,
|
|
37
|
+
):
|
|
38
|
+
self._providers = providers or self._init_providers()
|
|
39
|
+
self._session_options = session_options or self._init_sess_opts()
|
|
40
|
+
|
|
41
|
+
def _init_providers(self) -> List[Tuple[str, Dict[str, Any]]]:
|
|
42
|
+
providers: Any = [("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
|
|
43
|
+
available_providers = get_available_providers()
|
|
44
|
+
if "CUDAExecutionProvider" in available_providers and get_device() == "GPU": # pragma: no cover
|
|
45
|
+
providers.insert(
|
|
46
|
+
0,
|
|
47
|
+
(
|
|
48
|
+
"CUDAExecutionProvider",
|
|
49
|
+
{
|
|
50
|
+
"device_id": 0,
|
|
51
|
+
"arena_extend_strategy": "kNextPowerOfTwo",
|
|
52
|
+
"cudnn_conv_algo_search": "DEFAULT",
|
|
53
|
+
"do_copy_in_default_stream": True,
|
|
54
|
+
},
|
|
55
|
+
),
|
|
56
|
+
)
|
|
57
|
+
return providers
|
|
58
|
+
|
|
59
|
+
def _init_sess_opts(self) -> SessionOptions:
|
|
60
|
+
session_options = SessionOptions()
|
|
61
|
+
session_options.enable_cpu_mem_arena = True
|
|
62
|
+
session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
|
|
63
|
+
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
64
|
+
session_options.intra_op_num_threads = -1
|
|
65
|
+
session_options.inter_op_num_threads = -1
|
|
66
|
+
return session_options
|
|
67
|
+
|
|
68
|
+
@property
|
|
69
|
+
def providers(self) -> Union[List[Tuple[str, Dict[str, Any]]], List[str]]:
|
|
70
|
+
return self._providers
|
|
71
|
+
|
|
72
|
+
@property
|
|
73
|
+
def session_options(self) -> SessionOptions:
|
|
74
|
+
return self._session_options
|
|
75
|
+
|
|
76
|
+
def __repr__(self) -> str:
|
|
77
|
+
return f"EngineConfig(providers={self.providers}"
|
|
78
|
+
|
|
14
79
|
|
|
15
80
|
class Engine:
|
|
16
81
|
"""Implements an abstract class for the engine of a model
|
|
@@ -18,15 +83,16 @@ class Engine:
|
|
|
18
83
|
Args:
|
|
19
84
|
----
|
|
20
85
|
url: the url to use to download a model if needed
|
|
21
|
-
|
|
86
|
+
engine_cfg: the configuration of the engine
|
|
22
87
|
**kwargs: additional arguments to be passed to `download_from_url`
|
|
23
88
|
"""
|
|
24
89
|
|
|
25
|
-
def __init__(
|
|
26
|
-
|
|
27
|
-
) -> None:
|
|
90
|
+
def __init__(self, url: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any) -> None:
|
|
91
|
+
engine_cfg = engine_cfg if isinstance(engine_cfg, EngineConfig) else EngineConfig()
|
|
28
92
|
archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
|
|
29
|
-
self.
|
|
93
|
+
self.session_options = engine_cfg.session_options
|
|
94
|
+
self.providers = engine_cfg.providers
|
|
95
|
+
self.runtime = InferenceSession(archive_path, providers=self.providers, sess_options=self.session_options)
|
|
30
96
|
self.runtime_inputs = self.runtime.get_inputs()[0]
|
|
31
97
|
self.tf_exported = int(self.runtime_inputs.shape[-1]) == 3
|
|
32
98
|
self.fixed_batch_size: Union[int, str] = self.runtime_inputs.shape[
|