onnxtr 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/contrib/__init__.py +1 -0
- onnxtr/contrib/artefacts.py +6 -8
- onnxtr/contrib/base.py +7 -16
- onnxtr/file_utils.py +1 -3
- onnxtr/io/elements.py +54 -60
- onnxtr/io/html.py +0 -2
- onnxtr/io/image.py +1 -4
- onnxtr/io/pdf.py +3 -5
- onnxtr/io/reader.py +4 -10
- onnxtr/models/_utils.py +10 -17
- onnxtr/models/builder.py +17 -30
- onnxtr/models/classification/models/mobilenet.py +7 -12
- onnxtr/models/classification/predictor/base.py +6 -7
- onnxtr/models/classification/zoo.py +25 -11
- onnxtr/models/detection/_utils/base.py +3 -7
- onnxtr/models/detection/core.py +2 -8
- onnxtr/models/detection/models/differentiable_binarization.py +10 -17
- onnxtr/models/detection/models/fast.py +10 -17
- onnxtr/models/detection/models/linknet.py +10 -17
- onnxtr/models/detection/postprocessor/base.py +3 -9
- onnxtr/models/detection/predictor/base.py +4 -5
- onnxtr/models/detection/zoo.py +20 -6
- onnxtr/models/engine.py +9 -9
- onnxtr/models/factory/hub.py +3 -7
- onnxtr/models/predictor/base.py +29 -30
- onnxtr/models/predictor/predictor.py +4 -5
- onnxtr/models/preprocessor/base.py +8 -12
- onnxtr/models/recognition/core.py +0 -1
- onnxtr/models/recognition/models/crnn.py +11 -23
- onnxtr/models/recognition/models/master.py +9 -15
- onnxtr/models/recognition/models/parseq.py +8 -12
- onnxtr/models/recognition/models/sar.py +8 -12
- onnxtr/models/recognition/models/vitstr.py +9 -15
- onnxtr/models/recognition/predictor/_utils.py +6 -9
- onnxtr/models/recognition/predictor/base.py +3 -3
- onnxtr/models/recognition/utils.py +2 -7
- onnxtr/models/recognition/zoo.py +19 -7
- onnxtr/models/zoo.py +7 -9
- onnxtr/transforms/base.py +17 -6
- onnxtr/utils/common_types.py +7 -8
- onnxtr/utils/data.py +7 -11
- onnxtr/utils/fonts.py +1 -6
- onnxtr/utils/geometry.py +18 -49
- onnxtr/utils/multithreading.py +3 -5
- onnxtr/utils/reconstitution.py +139 -38
- onnxtr/utils/repr.py +1 -2
- onnxtr/utils/visualization.py +12 -21
- onnxtr/utils/vocabs.py +1 -2
- onnxtr/version.py +1 -1
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/METADATA +71 -41
- onnxtr-0.6.0.dist-info/RECORD +75 -0
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/WHEEL +1 -1
- onnxtr-0.5.0.dist-info/RECORD +0 -75
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/LICENSE +0 -0
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/top_level.txt +0 -0
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/zip-safe +0 -0
onnxtr/models/detection/models/differentiable_binarization.py
CHANGED

@@ -3,7 +3,7 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
 import numpy as np
 from scipy.special import expit
@@ -14,7 +14,7 @@ from ..postprocessor.base import GeneralDetectionPostProcessor
 __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "db_resnet50": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
@@ -43,7 +43,6 @@ class DBNet(Engine):
     """DBNet Onnx loader
 
     Args:
-    ----
         model_path: path or url to onnx model file
         engine_cfg: configuration for the inference engine
         bin_thresh: threshold for binarization of the output feature map
@@ -56,11 +55,11 @@ class DBNet(Engine):
     def __init__(
         self,
         model_path: str,
-        engine_cfg:
+        engine_cfg: EngineConfig | None = None,
         bin_thresh: float = 0.3,
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
-        cfg:
+        cfg: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -77,10 +76,10 @@ class DBNet(Engine):
         x: np.ndarray,
         return_model_output: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         logits = self.run(x)
 
-        out:
+        out: dict[str, Any] = {}
 
         prob_map = expit(logits)
         if return_model_output:
@@ -95,7 +94,7 @@ def _dbnet(
     arch: str,
     model_path: str,
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> DBNet:
     # Patch the url
@@ -107,7 +106,7 @@ def _dbnet(
 def db_resnet34(
     model_path: str = default_cfgs["db_resnet34"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> DBNet:
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
@@ -120,14 +119,12 @@ def db_resnet34(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -136,7 +133,7 @@ def db_resnet34(
 def db_resnet50(
     model_path: str = default_cfgs["db_resnet50"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> DBNet:
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
@@ -149,14 +146,12 @@ def db_resnet50(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -165,7 +160,7 @@ def db_resnet50(
 def db_mobilenet_v3_large(
     model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> DBNet:
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
@@ -178,14 +173,12 @@ def db_mobilenet_v3_large(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
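Note: the recurring pattern in this file (and the two detection model files below) is a move to PEP 604 builtin generics (`dict[str, Any]`, `EngineConfig | None`) plus removal of the numpydoc `----` underlines. A minimal usage sketch of the updated signature; the top-level `onnxtr.models` export and the NCHW float32 input matching the `(3, 1024, 1024)` entry of `default_cfgs` are assumptions, mirroring the docstring examples:

>>> import numpy as np
>>> from onnxtr.models import db_resnet50  # assumed top-level export
>>> model = db_resnet50()  # engine_cfg is now EngineConfig | None, defaulting to None
>>> input_tensor = np.random.rand(1, 3, 1024, 1024).astype(np.float32)  # assumed input shape
>>> out = model(input_tensor)  # annotated as dict[str, Any] in 0.6.0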
onnxtr/models/detection/models/fast.py
CHANGED

@@ -4,7 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-from typing import Any
+from typing import Any
 
 import numpy as np
 from scipy.special import expit
@@ -15,7 +15,7 @@ from ..postprocessor.base import GeneralDetectionPostProcessor
 __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "fast_tiny": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
@@ -41,7 +41,6 @@ class FAST(Engine):
     """FAST Onnx loader
 
     Args:
-    ----
         model_path: path or url to onnx model file
         engine_cfg: configuration for the inference engine
         bin_thresh: threshold for binarization of the output feature map
@@ -54,11 +53,11 @@ class FAST(Engine):
     def __init__(
         self,
         model_path: str,
-        engine_cfg:
+        engine_cfg: EngineConfig | None = None,
         bin_thresh: float = 0.1,
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
-        cfg:
+        cfg: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -75,10 +74,10 @@ class FAST(Engine):
         x: np.ndarray,
         return_model_output: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         logits = self.run(x)
 
-        out:
+        out: dict[str, Any] = {}
 
         prob_map = expit(logits)
         if return_model_output:
@@ -93,7 +92,7 @@ def _fast(
     arch: str,
     model_path: str,
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> FAST:
     if load_in_8_bit:
@@ -105,7 +104,7 @@ def _fast(
 def fast_tiny(
     model_path: str = default_cfgs["fast_tiny"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> FAST:
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
@@ -118,14 +117,12 @@ def fast_tiny(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _fast("fast_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -134,7 +131,7 @@ def fast_tiny(
 def fast_small(
     model_path: str = default_cfgs["fast_small"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> FAST:
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
@@ -147,14 +144,12 @@ def fast_small(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _fast("fast_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -163,7 +158,7 @@ def fast_small(
 def fast_base(
     model_path: str = default_cfgs["fast_base"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> FAST:
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
@@ -176,14 +171,12 @@ def fast_base(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _fast("fast_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
onnxtr/models/detection/models/linknet.py
CHANGED

@@ -3,7 +3,7 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
 import numpy as np
 from scipy.special import expit
@@ -14,7 +14,7 @@ from ..postprocessor.base import GeneralDetectionPostProcessor
 __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "linknet_resnet18": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
@@ -43,7 +43,6 @@ class LinkNet(Engine):
     """LinkNet Onnx loader
 
     Args:
-    ----
         model_path: path or url to onnx model file
         engine_cfg: configuration for the inference engine
         bin_thresh: threshold for binarization of the output feature map
@@ -56,11 +55,11 @@ class LinkNet(Engine):
     def __init__(
         self,
         model_path: str,
-        engine_cfg:
+        engine_cfg: EngineConfig | None = None,
         bin_thresh: float = 0.1,
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
-        cfg:
+        cfg: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -77,10 +76,10 @@ class LinkNet(Engine):
         x: np.ndarray,
         return_model_output: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         logits = self.run(x)
 
-        out:
+        out: dict[str, Any] = {}
 
         prob_map = expit(logits)
         if return_model_output:
@@ -95,7 +94,7 @@ def _linknet(
     arch: str,
     model_path: str,
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> LinkNet:
     # Patch the url
@@ -107,7 +106,7 @@ def _linknet(
 def linknet_resnet18(
     model_path: str = default_cfgs["linknet_resnet18"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> LinkNet:
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
@@ -120,14 +119,12 @@ def linknet_resnet18(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the LinkNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _linknet("linknet_resnet18", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -136,7 +133,7 @@ def linknet_resnet18(
 def linknet_resnet34(
     model_path: str = default_cfgs["linknet_resnet34"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> LinkNet:
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
@@ -149,14 +146,12 @@ def linknet_resnet34(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the LinkNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _linknet("linknet_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -165,7 +160,7 @@ def linknet_resnet34(
 def linknet_resnet50(
     model_path: str = default_cfgs["linknet_resnet50"]["url"],
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> LinkNet:
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
@@ -178,14 +173,12 @@ def linknet_resnet50(
     >>> out = model(input_tensor)
 
     Args:
-    ----
         model_path: path to onnx model file, defaults to url in default_cfgs
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: keyword arguments of the LinkNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _linknet("linknet_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
onnxtr/models/detection/postprocessor/base.py
CHANGED

@@ -5,7 +5,6 @@
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import List, Union
 
 import cv2
 import numpy as np
@@ -21,7 +20,6 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
     """Implements a post processor for FAST model.
 
     Args:
-    ----
         bin_thresh: threshold used to binzarized p_map at inference time
         box_thresh: minimal objectness score to consider a box
         assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -43,11 +41,9 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
-        ----
             points: The first parameter.
 
         Returns:
-        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -92,24 +88,22 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
-        ----
             pred: Pred map from differentiable linknet output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
-        -------
             np tensor boxes for the bitmap, each box is a 6-element list
-
+                containing x, y, w, h, alpha, score for the box
         """
         height, width = bitmap.shape[:2]
-        boxes:
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
onnxtr/models/detection/predictor/base.py
CHANGED

@@ -3,7 +3,7 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
 import numpy as np
 
@@ -18,12 +18,11 @@ class DetectionPredictor(NestedObject):
     """Implements an object able to localize text elements in a document
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core detection architecture
     """
 
-    _children_names:
+    _children_names: list[str] = ["pre_processor", "model"]
 
     def __init__(
         self,
@@ -35,10 +34,10 @@ class DetectionPredictor(NestedObject):
 
     def __call__(
         self,
-        pages:
+        pages: list[np.ndarray],
        return_maps: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> list[np.ndarray] | tuple[list[np.ndarray], list[np.ndarray]]:
         # Extract parameters from the preprocessor
         preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
         symmetric_pad = self.pre_processor.resize.symmetric_pad
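Note: `DetectionPredictor.__call__` is now annotated to take `list[np.ndarray]` and to return either a list of box arrays or, with `return_maps=True`, a `(boxes, maps)` pair. A sketch of the two call shapes inferred from that annotation (the `predictor` instance is assumed to come from `detection_predictor` in the zoo file below; `page` is a hypothetical numpy page image):

>>> boxes = predictor([page])  # list[np.ndarray]
>>> boxes, maps = predictor([page], return_maps=True)  # tuple[list[np.ndarray], list[np.ndarray]]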
onnxtr/models/detection/zoo.py
CHANGED

@@ -3,7 +3,7 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
 from .. import detection
 from ..engine import EngineConfig
@@ -29,7 +29,7 @@ def _predictor(
     arch: Any,
     assume_straight_pages: bool = True,
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> DetectionPredictor:
     if isinstance(arch, str):
@@ -59,8 +59,11 @@ def _predictor(
 def detection_predictor(
     arch: Any = "fast_base",
     assume_straight_pages: bool = True,
+    preserve_aspect_ratio: bool = True,
+    symmetric_pad: bool = True,
+    batch_size: int = 2,
     load_in_8_bit: bool = False,
-    engine_cfg:
+    engine_cfg: EngineConfig | None = None,
     **kwargs: Any,
 ) -> DetectionPredictor:
     """Text detection architecture.
@@ -72,15 +75,26 @@
     >>> out = model([input_page])
 
     Args:
-    ----
         arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
         assume_straight_pages: If True, fit straight boxes to the page
+        preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
+            running the detection model on it
+        symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
+        batch_size: number of samples the model processes in parallel
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         engine_cfg: configuration for the inference engine
         **kwargs: optional keyword arguments passed to the architecture
 
     Returns:
-    -------
         Detection predictor
     """
-    return _predictor(
+    return _predictor(
+        arch=arch,
+        assume_straight_pages=assume_straight_pages,
+        preserve_aspect_ratio=preserve_aspect_ratio,
+        symmetric_pad=symmetric_pad,
+        batch_size=batch_size,
+        load_in_8_bit=load_in_8_bit,
+        engine_cfg=engine_cfg,
+        **kwargs,
+    )
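Note: `detection_predictor` now surfaces `preserve_aspect_ratio`, `symmetric_pad` and `batch_size` as explicit keyword arguments and forwards everything to `_predictor` by name. A sketch of the new call, using the defaults from the signature above:

>>> from onnxtr.models import detection_predictor
>>> predictor = detection_predictor(
...     arch="fast_base",
...     assume_straight_pages=True,
...     preserve_aspect_ratio=True,
...     symmetric_pad=True,
...     batch_size=2,
... )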
onnxtr/models/engine.py
CHANGED

@@ -3,7 +3,8 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-
+import logging
+from typing import Any
 
 import numpy as np
 from onnxruntime import (
@@ -25,22 +26,22 @@ class EngineConfig:
     """Implements a configuration class for the engine of a model
 
     Args:
-    ----
         providers: list of providers to use for inference ref.: https://onnxruntime.ai/docs/execution-providers/
         session_options: configuration for the inference session ref.: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
     """
 
     def __init__(
         self,
-        providers:
-        session_options:
+        providers: list[tuple[str, dict[str, Any]]] | list[str] | None = None,
+        session_options: SessionOptions | None = None,
     ):
         self._providers = providers or self._init_providers()
         self._session_options = session_options or self._init_sess_opts()
 
-    def _init_providers(self) ->
+    def _init_providers(self) -> list[tuple[str, dict[str, Any]]]:
         providers: Any = [("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
         available_providers = get_available_providers()
+        logging.info(f"Available providers: {available_providers}")
         if "CUDAExecutionProvider" in available_providers and get_device() == "GPU":  # pragma: no cover
             providers.insert(
                 0,
@@ -66,7 +67,7 @@ class EngineConfig:
         return session_options
 
     @property
-    def providers(self) ->
+    def providers(self) -> list[tuple[str, dict[str, Any]]] | list[str]:
         return self._providers
 
     @property
@@ -81,13 +82,12 @@ class Engine:
     """Implements an abstract class for the engine of a model
 
     Args:
-    ----
         url: the url to use to download a model if needed
         engine_cfg: the configuration of the engine
         **kwargs: additional arguments to be passed to `download_from_url`
     """
 
-    def __init__(self, url: str, engine_cfg:
+    def __init__(self, url: str, engine_cfg: EngineConfig | None = None, **kwargs: Any) -> None:
         engine_cfg = engine_cfg if isinstance(engine_cfg, EngineConfig) else EngineConfig()
         archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
         # Store model path for each model
@@ -97,7 +97,7 @@ class Engine:
         self.runtime = InferenceSession(archive_path, providers=self.providers, sess_options=self.session_options)
         self.runtime_inputs = self.runtime.get_inputs()[0]
         self.tf_exported = int(self.runtime_inputs.shape[-1]) == 3
-        self.fixed_batch_size:
+        self.fixed_batch_size: int | str = self.runtime_inputs.shape[
             0
         ]  # mostly possible with tensorflow exported models
         self.output_name = [output.name for output in self.runtime.get_outputs()]
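Note: `EngineConfig` keeps its behaviour but its annotations now spell out what a provider entry looks like: either a bare provider name or a `(name, options)` tuple, plus an informational log of the available providers. A minimal sketch of passing an explicit configuration (the provider/options pair is copied from `_init_providers` above; the import path follows the `from ..engine import EngineConfig` seen in the zoo diff):

>>> from onnxruntime import SessionOptions
>>> from onnxtr.models.engine import EngineConfig
>>> cfg = EngineConfig(
...     providers=[("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})],
...     session_options=SessionOptions(),
... )
>>> cfg.providers  # list[tuple[str, dict[str, Any]]] | list[str]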
onnxtr/models/factory/hub.py
CHANGED

@@ -12,7 +12,7 @@ import shutil
 import subprocess
 import textwrap
 from pathlib import Path
-from typing import Any
+from typing import Any
 
 from huggingface_hub import (
     HfApi,
@@ -59,7 +59,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
     """Save model and config to disk for pushing to huggingface hub
 
     Args:
-    ----
         model: Onnx model to be saved
         save_dir: directory to save model and config
         arch: architecture name
@@ -91,7 +90,6 @@ def push_to_hf_hub(
     >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
 
     Args:
-    ----
         model: Onnx model to be saved
         model_name: name of the model which is also the repository name
         task: task name
@@ -179,20 +177,18 @@ def push_to_hf_hub(
     repo.git_push()
 
 
-def from_hub(repo_id: str, engine_cfg:
+def from_hub(repo_id: str, engine_cfg: EngineConfig | None = None, **kwargs: Any):
     """Instantiate & load a pretrained model from HF hub.
 
     >>> from onnxtr.models import from_hub
     >>> model = from_hub("onnxtr/my-model")
 
     Args:
-    ----
         repo_id: HuggingFace model hub repo
         engine_cfg: configuration for the inference engine (optional)
-        kwargs: kwargs of `hf_hub_download`
+        **kwargs: kwargs of `hf_hub_download`
 
     Returns:
-    -------
         Model loaded with the checkpoint
     """
     # Get the config