onnxtr 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/contrib/__init__.py +1 -0
- onnxtr/contrib/artefacts.py +6 -8
- onnxtr/contrib/base.py +7 -16
- onnxtr/file_utils.py +1 -3
- onnxtr/io/elements.py +54 -60
- onnxtr/io/html.py +0 -2
- onnxtr/io/image.py +1 -4
- onnxtr/io/pdf.py +3 -5
- onnxtr/io/reader.py +4 -10
- onnxtr/models/_utils.py +10 -17
- onnxtr/models/builder.py +17 -30
- onnxtr/models/classification/models/mobilenet.py +7 -12
- onnxtr/models/classification/predictor/base.py +6 -7
- onnxtr/models/classification/zoo.py +25 -11
- onnxtr/models/detection/_utils/base.py +3 -7
- onnxtr/models/detection/core.py +2 -8
- onnxtr/models/detection/models/differentiable_binarization.py +10 -17
- onnxtr/models/detection/models/fast.py +10 -17
- onnxtr/models/detection/models/linknet.py +10 -17
- onnxtr/models/detection/postprocessor/base.py +3 -9
- onnxtr/models/detection/predictor/base.py +4 -5
- onnxtr/models/detection/zoo.py +20 -6
- onnxtr/models/engine.py +9 -9
- onnxtr/models/factory/hub.py +3 -7
- onnxtr/models/predictor/base.py +29 -30
- onnxtr/models/predictor/predictor.py +4 -5
- onnxtr/models/preprocessor/base.py +8 -12
- onnxtr/models/recognition/core.py +0 -1
- onnxtr/models/recognition/models/crnn.py +11 -23
- onnxtr/models/recognition/models/master.py +9 -15
- onnxtr/models/recognition/models/parseq.py +8 -12
- onnxtr/models/recognition/models/sar.py +8 -12
- onnxtr/models/recognition/models/vitstr.py +9 -15
- onnxtr/models/recognition/predictor/_utils.py +6 -9
- onnxtr/models/recognition/predictor/base.py +3 -3
- onnxtr/models/recognition/utils.py +2 -7
- onnxtr/models/recognition/zoo.py +19 -7
- onnxtr/models/zoo.py +7 -9
- onnxtr/transforms/base.py +17 -6
- onnxtr/utils/common_types.py +7 -8
- onnxtr/utils/data.py +7 -11
- onnxtr/utils/fonts.py +1 -6
- onnxtr/utils/geometry.py +18 -49
- onnxtr/utils/multithreading.py +3 -5
- onnxtr/utils/reconstitution.py +139 -38
- onnxtr/utils/repr.py +1 -2
- onnxtr/utils/visualization.py +12 -21
- onnxtr/utils/vocabs.py +1 -2
- onnxtr/version.py +1 -1
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/METADATA +71 -41
- onnxtr-0.6.0.dist-info/RECORD +75 -0
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/WHEEL +1 -1
- onnxtr-0.5.0.dist-info/RECORD +0 -75
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/LICENSE +0 -0
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/top_level.txt +0 -0
- {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/zip-safe +0 -0
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from scipy.special import softmax
|
|
@@ -16,7 +16,7 @@ from ..core import RecognitionPostProcessor
|
|
|
16
16
|
|
|
17
17
|
__all__ = ["SAR", "sar_resnet31"]
|
|
18
18
|
|
|
19
|
-
default_cfgs:
|
|
19
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
20
20
|
"sar_resnet31": {
|
|
21
21
|
"mean": (0.694, 0.695, 0.693),
|
|
22
22
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -32,7 +32,6 @@ class SAR(Engine):
|
|
|
32
32
|
"""SAR Onnx loader
|
|
33
33
|
|
|
34
34
|
Args:
|
|
35
|
-
----
|
|
36
35
|
model_path: path to onnx model file
|
|
37
36
|
vocab: vocabulary used for encoding
|
|
38
37
|
engine_cfg: configuration for the inference engine
|
|
@@ -44,8 +43,8 @@ class SAR(Engine):
|
|
|
44
43
|
self,
|
|
45
44
|
model_path: str,
|
|
46
45
|
vocab: str,
|
|
47
|
-
engine_cfg:
|
|
48
|
-
cfg:
|
|
46
|
+
engine_cfg: EngineConfig | None = None,
|
|
47
|
+
cfg: dict[str, Any] | None = None,
|
|
49
48
|
**kwargs: Any,
|
|
50
49
|
) -> None:
|
|
51
50
|
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
@@ -59,10 +58,10 @@ class SAR(Engine):
|
|
|
59
58
|
self,
|
|
60
59
|
x: np.ndarray,
|
|
61
60
|
return_model_output: bool = False,
|
|
62
|
-
) ->
|
|
61
|
+
) -> dict[str, Any]:
|
|
63
62
|
logits = self.run(x)
|
|
64
63
|
|
|
65
|
-
out:
|
|
64
|
+
out: dict[str, Any] = {}
|
|
66
65
|
if return_model_output:
|
|
67
66
|
out["out_map"] = logits
|
|
68
67
|
|
|
@@ -75,7 +74,6 @@ class SARPostProcessor(RecognitionPostProcessor):
|
|
|
75
74
|
"""Post processor for SAR architectures
|
|
76
75
|
|
|
77
76
|
Args:
|
|
78
|
-
----
|
|
79
77
|
embedding: string containing the ordered sequence of supported characters
|
|
80
78
|
"""
|
|
81
79
|
|
|
@@ -105,7 +103,7 @@ def _sar(
|
|
|
105
103
|
arch: str,
|
|
106
104
|
model_path: str,
|
|
107
105
|
load_in_8_bit: bool = False,
|
|
108
|
-
engine_cfg:
|
|
106
|
+
engine_cfg: EngineConfig | None = None,
|
|
109
107
|
**kwargs: Any,
|
|
110
108
|
) -> SAR:
|
|
111
109
|
# Patch the config
|
|
@@ -124,7 +122,7 @@ def _sar(
|
|
|
124
122
|
def sar_resnet31(
|
|
125
123
|
model_path: str = default_cfgs["sar_resnet31"]["url"],
|
|
126
124
|
load_in_8_bit: bool = False,
|
|
127
|
-
engine_cfg:
|
|
125
|
+
engine_cfg: EngineConfig | None = None,
|
|
128
126
|
**kwargs: Any,
|
|
129
127
|
) -> SAR:
|
|
130
128
|
"""SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong
|
|
@@ -137,14 +135,12 @@ def sar_resnet31(
|
|
|
137
135
|
>>> out = model(input_tensor)
|
|
138
136
|
|
|
139
137
|
Args:
|
|
140
|
-
----
|
|
141
138
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
142
139
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
143
140
|
engine_cfg: configuration for the inference engine
|
|
144
141
|
**kwargs: keyword arguments of the SAR architecture
|
|
145
142
|
|
|
146
143
|
Returns:
|
|
147
|
-
-------
|
|
148
144
|
text recognition architecture
|
|
149
145
|
"""
|
|
150
146
|
return _sar("sar_resnet31", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from scipy.special import softmax
|
|
@@ -16,7 +16,7 @@ from ..core import RecognitionPostProcessor
|
|
|
16
16
|
|
|
17
17
|
__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
|
|
18
18
|
|
|
19
|
-
default_cfgs:
|
|
19
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
20
20
|
"vitstr_small": {
|
|
21
21
|
"mean": (0.694, 0.695, 0.693),
|
|
22
22
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -40,7 +40,6 @@ class ViTSTR(Engine):
|
|
|
40
40
|
"""ViTSTR Onnx loader
|
|
41
41
|
|
|
42
42
|
Args:
|
|
43
|
-
----
|
|
44
43
|
model_path: path to onnx model file
|
|
45
44
|
vocab: vocabulary used for encoding
|
|
46
45
|
engine_cfg: configuration for the inference engine
|
|
@@ -52,8 +51,8 @@ class ViTSTR(Engine):
|
|
|
52
51
|
self,
|
|
53
52
|
model_path: str,
|
|
54
53
|
vocab: str,
|
|
55
|
-
engine_cfg:
|
|
56
|
-
cfg:
|
|
54
|
+
engine_cfg: EngineConfig | None = None,
|
|
55
|
+
cfg: dict[str, Any] | None = None,
|
|
57
56
|
**kwargs: Any,
|
|
58
57
|
) -> None:
|
|
59
58
|
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
|
|
@@ -67,10 +66,10 @@ class ViTSTR(Engine):
|
|
|
67
66
|
self,
|
|
68
67
|
x: np.ndarray,
|
|
69
68
|
return_model_output: bool = False,
|
|
70
|
-
) ->
|
|
69
|
+
) -> dict[str, Any]:
|
|
71
70
|
logits = self.run(x)
|
|
72
71
|
|
|
73
|
-
out:
|
|
72
|
+
out: dict[str, Any] = {}
|
|
74
73
|
if return_model_output:
|
|
75
74
|
out["out_map"] = logits
|
|
76
75
|
|
|
@@ -83,7 +82,6 @@ class ViTSTRPostProcessor(RecognitionPostProcessor):
|
|
|
83
82
|
"""Post processor for ViTSTR architecture
|
|
84
83
|
|
|
85
84
|
Args:
|
|
86
|
-
----
|
|
87
85
|
vocab: string containing the ordered sequence of supported characters
|
|
88
86
|
"""
|
|
89
87
|
|
|
@@ -115,7 +113,7 @@ def _vitstr(
|
|
|
115
113
|
arch: str,
|
|
116
114
|
model_path: str,
|
|
117
115
|
load_in_8_bit: bool = False,
|
|
118
|
-
engine_cfg:
|
|
116
|
+
engine_cfg: EngineConfig | None = None,
|
|
119
117
|
**kwargs: Any,
|
|
120
118
|
) -> ViTSTR:
|
|
121
119
|
# Patch the config
|
|
@@ -134,7 +132,7 @@ def _vitstr(
|
|
|
134
132
|
def vitstr_small(
|
|
135
133
|
model_path: str = default_cfgs["vitstr_small"]["url"],
|
|
136
134
|
load_in_8_bit: bool = False,
|
|
137
|
-
engine_cfg:
|
|
135
|
+
engine_cfg: EngineConfig | None = None,
|
|
138
136
|
**kwargs: Any,
|
|
139
137
|
) -> ViTSTR:
|
|
140
138
|
"""ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
|
|
@@ -147,14 +145,12 @@ def vitstr_small(
|
|
|
147
145
|
>>> out = model(input_tensor)
|
|
148
146
|
|
|
149
147
|
Args:
|
|
150
|
-
----
|
|
151
148
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
152
149
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
153
150
|
engine_cfg: configuration for the inference engine
|
|
154
151
|
**kwargs: keyword arguments of the ViTSTR architecture
|
|
155
152
|
|
|
156
153
|
Returns:
|
|
157
|
-
-------
|
|
158
154
|
text recognition architecture
|
|
159
155
|
"""
|
|
160
156
|
return _vitstr("vitstr_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -163,7 +159,7 @@ def vitstr_small(
|
|
|
163
159
|
def vitstr_base(
|
|
164
160
|
model_path: str = default_cfgs["vitstr_base"]["url"],
|
|
165
161
|
load_in_8_bit: bool = False,
|
|
166
|
-
engine_cfg:
|
|
162
|
+
engine_cfg: EngineConfig | None = None,
|
|
167
163
|
**kwargs: Any,
|
|
168
164
|
) -> ViTSTR:
|
|
169
165
|
"""ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
|
|
@@ -176,14 +172,12 @@ def vitstr_base(
|
|
|
176
172
|
>>> out = model(input_tensor)
|
|
177
173
|
|
|
178
174
|
Args:
|
|
179
|
-
----
|
|
180
175
|
model_path: path to onnx model file, defaults to url in default_cfgs
|
|
181
176
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
182
177
|
engine_cfg: configuration for the inference engine
|
|
183
178
|
**kwargs: keyword arguments of the ViTSTR architecture
|
|
184
179
|
|
|
185
180
|
Returns:
|
|
186
|
-
-------
|
|
187
181
|
text recognition architecture
|
|
188
182
|
"""
|
|
189
183
|
return _vitstr("vitstr_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List, Tuple, Union
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
|
|
@@ -13,16 +12,15 @@ __all__ = ["split_crops", "remap_preds"]
|
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
def split_crops(
|
|
16
|
-
crops:
|
|
15
|
+
crops: list[np.ndarray],
|
|
17
16
|
max_ratio: float,
|
|
18
17
|
target_ratio: int,
|
|
19
18
|
dilation: float,
|
|
20
19
|
channels_last: bool = True,
|
|
21
|
-
) ->
|
|
20
|
+
) -> tuple[list[np.ndarray], list[int | tuple[int, int]], bool]:
|
|
22
21
|
"""Chunk crops horizontally to match a given aspect ratio
|
|
23
22
|
|
|
24
23
|
Args:
|
|
25
|
-
----
|
|
26
24
|
crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
|
|
27
25
|
max_ratio: the maximum aspect ratio that won't trigger the chunk
|
|
28
26
|
target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
|
|
@@ -30,12 +28,11 @@ def split_crops(
|
|
|
30
28
|
channels_last: whether the numpy array has dimensions in channels last order
|
|
31
29
|
|
|
32
30
|
Returns:
|
|
33
|
-
-------
|
|
34
31
|
a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
|
|
35
32
|
"""
|
|
36
33
|
_remap_required = False
|
|
37
|
-
crop_map:
|
|
38
|
-
new_crops:
|
|
34
|
+
crop_map: list[int | tuple[int, int]] = []
|
|
35
|
+
new_crops: list[np.ndarray] = []
|
|
39
36
|
for crop in crops:
|
|
40
37
|
h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
|
|
41
38
|
aspect_ratio = w / h
|
|
@@ -71,8 +68,8 @@ def split_crops(
|
|
|
71
68
|
|
|
72
69
|
|
|
73
70
|
def remap_preds(
|
|
74
|
-
preds:
|
|
75
|
-
) ->
|
|
71
|
+
preds: list[tuple[str, float]], crop_map: list[int | tuple[int, int]], dilation: float
|
|
72
|
+
) -> list[tuple[str, float]]:
|
|
76
73
|
remapped_out = []
|
|
77
74
|
for _idx in crop_map:
|
|
78
75
|
# Crop hasn't been split
|
|
@@ -3,7 +3,8 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
|
|
@@ -19,7 +20,6 @@ class RecognitionPredictor(NestedObject):
|
|
|
19
20
|
"""Implements an object able to identify character sequences in images
|
|
20
21
|
|
|
21
22
|
Args:
|
|
22
|
-
----
|
|
23
23
|
pre_processor: transform inputs for easier batched model inference
|
|
24
24
|
model: core recognition architecture
|
|
25
25
|
split_wide_crops: whether to use crop splitting for high aspect ratio crops
|
|
@@ -43,7 +43,7 @@ class RecognitionPredictor(NestedObject):
|
|
|
43
43
|
self,
|
|
44
44
|
crops: Sequence[np.ndarray],
|
|
45
45
|
**kwargs: Any,
|
|
46
|
-
) ->
|
|
46
|
+
) -> list[tuple[str, float]]:
|
|
47
47
|
if len(crops) == 0:
|
|
48
48
|
return []
|
|
49
49
|
# Dimension check
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List
|
|
7
6
|
|
|
8
7
|
from rapidfuzz.distance import Levenshtein
|
|
9
8
|
|
|
@@ -14,14 +13,12 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
|
|
|
14
13
|
"""Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
|
|
15
14
|
|
|
16
15
|
Args:
|
|
17
|
-
----
|
|
18
16
|
a: first char seq, suffix should be similar to b's prefix.
|
|
19
17
|
b: second char seq, prefix should be similar to a's suffix.
|
|
20
18
|
dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
|
|
21
19
|
only used when the mother sequence is split on a character repetition
|
|
22
20
|
|
|
23
21
|
Returns:
|
|
24
|
-
-------
|
|
25
22
|
A merged character sequence.
|
|
26
23
|
|
|
27
24
|
Example::
|
|
@@ -61,17 +58,15 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
|
|
|
61
58
|
return a[:-1] + b[index - 1 :]
|
|
62
59
|
|
|
63
60
|
|
|
64
|
-
def merge_multi_strings(seq_list:
|
|
61
|
+
def merge_multi_strings(seq_list: list[str], dil_factor: float) -> str:
|
|
65
62
|
"""Recursively merges consecutive string sequences with overlapping characters.
|
|
66
63
|
|
|
67
64
|
Args:
|
|
68
|
-
----
|
|
69
65
|
seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
|
|
70
66
|
dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
|
|
71
67
|
only used when the mother sequence is split on a character repetition
|
|
72
68
|
|
|
73
69
|
Returns:
|
|
74
|
-
-------
|
|
75
70
|
A merged character sequence
|
|
76
71
|
|
|
77
72
|
Example::
|
|
@@ -80,7 +75,7 @@ def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
|
|
|
80
75
|
'abcdefghijkl'
|
|
81
76
|
"""
|
|
82
77
|
|
|
83
|
-
def _recursive_merge(a: str, seq_list:
|
|
78
|
+
def _recursive_merge(a: str, seq_list: list[str], dil_factor: float) -> str:
|
|
84
79
|
# Recursive version of compute_overlap
|
|
85
80
|
if len(seq_list) == 1:
|
|
86
81
|
return merge_strings(a, seq_list[0], dil_factor)
|
onnxtr/models/recognition/zoo.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
from .. import recognition
|
|
9
9
|
from ..engine import EngineConfig
|
|
@@ -13,7 +13,7 @@ from .predictor import RecognitionPredictor
|
|
|
13
13
|
__all__ = ["recognition_predictor"]
|
|
14
14
|
|
|
15
15
|
|
|
16
|
-
ARCHS:
|
|
16
|
+
ARCHS: list[str] = [
|
|
17
17
|
"crnn_vgg16_bn",
|
|
18
18
|
"crnn_mobilenet_v3_small",
|
|
19
19
|
"crnn_mobilenet_v3_large",
|
|
@@ -26,7 +26,7 @@ ARCHS: List[str] = [
|
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def _predictor(
|
|
29
|
-
arch: Any, load_in_8_bit: bool = False, engine_cfg:
|
|
29
|
+
arch: Any, load_in_8_bit: bool = False, engine_cfg: EngineConfig | None = None, **kwargs: Any
|
|
30
30
|
) -> RecognitionPredictor:
|
|
31
31
|
if isinstance(arch, str):
|
|
32
32
|
if arch not in ARCHS:
|
|
@@ -50,7 +50,12 @@ def _predictor(
|
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
def recognition_predictor(
|
|
53
|
-
arch: Any = "crnn_vgg16_bn",
|
|
53
|
+
arch: Any = "crnn_vgg16_bn",
|
|
54
|
+
symmetric_pad: bool = False,
|
|
55
|
+
batch_size: int = 128,
|
|
56
|
+
load_in_8_bit: bool = False,
|
|
57
|
+
engine_cfg: EngineConfig | None = None,
|
|
58
|
+
**kwargs: Any,
|
|
54
59
|
) -> RecognitionPredictor:
|
|
55
60
|
"""Text recognition architecture.
|
|
56
61
|
|
|
@@ -62,14 +67,21 @@ def recognition_predictor(
|
|
|
62
67
|
>>> out = model([input_page])
|
|
63
68
|
|
|
64
69
|
Args:
|
|
65
|
-
----
|
|
66
70
|
arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
|
|
71
|
+
symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
|
|
72
|
+
batch_size: number of samples the model processes in parallel
|
|
67
73
|
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
|
|
68
74
|
engine_cfg: configuration of inference engine
|
|
69
75
|
**kwargs: optional parameters to be passed to the architecture
|
|
70
76
|
|
|
71
77
|
Returns:
|
|
72
|
-
-------
|
|
73
78
|
Recognition predictor
|
|
74
79
|
"""
|
|
75
|
-
return _predictor(
|
|
80
|
+
return _predictor(
|
|
81
|
+
arch=arch,
|
|
82
|
+
symmetric_pad=symmetric_pad,
|
|
83
|
+
batch_size=batch_size,
|
|
84
|
+
load_in_8_bit=load_in_8_bit,
|
|
85
|
+
engine_cfg=engine_cfg,
|
|
86
|
+
**kwargs,
|
|
87
|
+
)
|
onnxtr/models/zoo.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
from .detection.zoo import detection_predictor
|
|
9
9
|
from .engine import EngineConfig
|
|
@@ -25,9 +25,9 @@ def _predictor(
|
|
|
25
25
|
straighten_pages: bool = False,
|
|
26
26
|
detect_language: bool = False,
|
|
27
27
|
load_in_8_bit: bool = False,
|
|
28
|
-
det_engine_cfg:
|
|
29
|
-
reco_engine_cfg:
|
|
30
|
-
clf_engine_cfg:
|
|
28
|
+
det_engine_cfg: EngineConfig | None = None,
|
|
29
|
+
reco_engine_cfg: EngineConfig | None = None,
|
|
30
|
+
clf_engine_cfg: EngineConfig | None = None,
|
|
31
31
|
**kwargs,
|
|
32
32
|
) -> OCRPredictor:
|
|
33
33
|
# Detection
|
|
@@ -74,9 +74,9 @@ def ocr_predictor(
|
|
|
74
74
|
straighten_pages: bool = False,
|
|
75
75
|
detect_language: bool = False,
|
|
76
76
|
load_in_8_bit: bool = False,
|
|
77
|
-
det_engine_cfg:
|
|
78
|
-
reco_engine_cfg:
|
|
79
|
-
clf_engine_cfg:
|
|
77
|
+
det_engine_cfg: EngineConfig | None = None,
|
|
78
|
+
reco_engine_cfg: EngineConfig | None = None,
|
|
79
|
+
clf_engine_cfg: EngineConfig | None = None,
|
|
80
80
|
**kwargs: Any,
|
|
81
81
|
) -> OCRPredictor:
|
|
82
82
|
"""End-to-end OCR architecture using one model for localization, and another for text recognition.
|
|
@@ -88,7 +88,6 @@ def ocr_predictor(
|
|
|
88
88
|
>>> out = model([input_page])
|
|
89
89
|
|
|
90
90
|
Args:
|
|
91
|
-
----
|
|
92
91
|
det_arch: name of the detection architecture or the model itself to use
|
|
93
92
|
(e.g. 'db_resnet50', 'db_mobilenet_v3_large')
|
|
94
93
|
reco_arch: name of the recognition architecture or the model itself to use
|
|
@@ -115,7 +114,6 @@ def ocr_predictor(
|
|
|
115
114
|
kwargs: keyword args of `OCRPredictor`
|
|
116
115
|
|
|
117
116
|
Returns:
|
|
118
|
-
-------
|
|
119
117
|
OCR predictor
|
|
120
118
|
"""
|
|
121
119
|
return _predictor(
|
onnxtr/transforms/base.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Tuple, Union
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
from PIL import Image, ImageOps
|
|
@@ -12,11 +11,18 @@ __all__ = ["Resize", "Normalize"]
|
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
class Resize:
|
|
15
|
-
"""Resize the input image to the given size
|
|
14
|
+
"""Resize the input image to the given size
|
|
15
|
+
|
|
16
|
+
Args:
|
|
17
|
+
size: the target size of the image
|
|
18
|
+
interpolation: the interpolation method to use
|
|
19
|
+
preserve_aspect_ratio: whether to preserve the aspect ratio of the image
|
|
20
|
+
symmetric_pad: whether to symmetrically pad the image
|
|
21
|
+
"""
|
|
16
22
|
|
|
17
23
|
def __init__(
|
|
18
24
|
self,
|
|
19
|
-
size:
|
|
25
|
+
size: int | tuple[int, int],
|
|
20
26
|
interpolation=Image.Resampling.BILINEAR,
|
|
21
27
|
preserve_aspect_ratio: bool = False,
|
|
22
28
|
symmetric_pad: bool = False,
|
|
@@ -72,12 +78,17 @@ class Resize:
|
|
|
72
78
|
|
|
73
79
|
|
|
74
80
|
class Normalize:
|
|
75
|
-
"""Normalize the input image
|
|
81
|
+
"""Normalize the input image
|
|
82
|
+
|
|
83
|
+
Args:
|
|
84
|
+
mean: mean values to subtract
|
|
85
|
+
std: standard deviation values to divide
|
|
86
|
+
"""
|
|
76
87
|
|
|
77
88
|
def __init__(
|
|
78
89
|
self,
|
|
79
|
-
mean:
|
|
80
|
-
std:
|
|
90
|
+
mean: float | tuple[float, float, float] = (0.485, 0.456, 0.406),
|
|
91
|
+
std: float | tuple[float, float, float] = (0.229, 0.224, 0.225),
|
|
81
92
|
) -> None:
|
|
82
93
|
self.mean = mean
|
|
83
94
|
self.std = std
|
onnxtr/utils/common_types.py
CHANGED
|
@@ -4,15 +4,14 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import List, Tuple, Union
|
|
8
7
|
|
|
9
8
|
__all__ = ["Point2D", "BoundingBox", "Polygon4P", "Polygon", "Bbox"]
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
Point2D =
|
|
13
|
-
BoundingBox =
|
|
14
|
-
Polygon4P =
|
|
15
|
-
Polygon =
|
|
16
|
-
AbstractPath =
|
|
17
|
-
AbstractFile =
|
|
18
|
-
Bbox =
|
|
11
|
+
Point2D = tuple[float, float]
|
|
12
|
+
BoundingBox = tuple[Point2D, Point2D]
|
|
13
|
+
Polygon4P = tuple[Point2D, Point2D, Point2D, Point2D]
|
|
14
|
+
Polygon = list[Point2D]
|
|
15
|
+
AbstractPath = str | Path
|
|
16
|
+
AbstractFile = AbstractPath | bytes
|
|
17
|
+
Bbox = tuple[float, float, float, float]
|
onnxtr/utils/data.py
CHANGED
|
@@ -13,7 +13,6 @@ import urllib
|
|
|
13
13
|
import urllib.error
|
|
14
14
|
import urllib.request
|
|
15
15
|
from pathlib import Path
|
|
16
|
-
from typing import Optional, Union
|
|
17
16
|
|
|
18
17
|
from tqdm.auto import tqdm
|
|
19
18
|
|
|
@@ -25,7 +24,7 @@ HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
|
|
|
25
24
|
USER_AGENT = "felixdittrich92/OnnxTR"
|
|
26
25
|
|
|
27
26
|
|
|
28
|
-
def _urlretrieve(url: str, filename:
|
|
27
|
+
def _urlretrieve(url: str, filename: Path | str, chunk_size: int = 1024) -> None:
|
|
29
28
|
with open(filename, "wb") as fh:
|
|
30
29
|
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
|
|
31
30
|
with tqdm(total=response.length) as pbar:
|
|
@@ -36,7 +35,7 @@ def _urlretrieve(url: str, filename: Union[Path, str], chunk_size: int = 1024) -
|
|
|
36
35
|
fh.write(chunk)
|
|
37
36
|
|
|
38
37
|
|
|
39
|
-
def _check_integrity(file_path:
|
|
38
|
+
def _check_integrity(file_path: str | Path, hash_prefix: str) -> bool:
|
|
40
39
|
with open(file_path, "rb") as f:
|
|
41
40
|
sha_hash = hashlib.sha256(f.read()).hexdigest()
|
|
42
41
|
|
|
@@ -45,10 +44,10 @@ def _check_integrity(file_path: Union[str, Path], hash_prefix: str) -> bool:
|
|
|
45
44
|
|
|
46
45
|
def download_from_url(
|
|
47
46
|
url: str,
|
|
48
|
-
file_name:
|
|
49
|
-
hash_prefix:
|
|
50
|
-
cache_dir:
|
|
51
|
-
cache_subdir:
|
|
47
|
+
file_name: str | None = None,
|
|
48
|
+
hash_prefix: str | None = None,
|
|
49
|
+
cache_dir: str | None = None,
|
|
50
|
+
cache_subdir: str | None = None,
|
|
52
51
|
) -> Path:
|
|
53
52
|
"""Download a file using its URL
|
|
54
53
|
|
|
@@ -56,7 +55,6 @@ def download_from_url(
|
|
|
56
55
|
>>> download_from_url("https://yoursource.com/yourcheckpoint-yourhash.zip")
|
|
57
56
|
|
|
58
57
|
Args:
|
|
59
|
-
----
|
|
60
58
|
url: the URL of the file to download
|
|
61
59
|
file_name: optional name of the file once downloaded
|
|
62
60
|
hash_prefix: optional expected SHA256 hash of the file
|
|
@@ -64,11 +62,9 @@ def download_from_url(
|
|
|
64
62
|
cache_subdir: subfolder to use in the cache
|
|
65
63
|
|
|
66
64
|
Returns:
|
|
67
|
-
-------
|
|
68
65
|
the location of the downloaded file
|
|
69
66
|
|
|
70
67
|
Note:
|
|
71
|
-
----
|
|
72
68
|
You can change cache directory location by using `ONNXTR_CACHE_DIR` environment variable.
|
|
73
69
|
"""
|
|
74
70
|
if not isinstance(file_name, str):
|
|
@@ -112,7 +108,7 @@ def download_from_url(
|
|
|
112
108
|
except (urllib.error.URLError, IOError) as e: # pragma: no cover
|
|
113
109
|
if url[:5] == "https":
|
|
114
110
|
url = url.replace("https:", "http:")
|
|
115
|
-
print("Failed download. Trying https -> http instead.
|
|
111
|
+
print(f"Failed download. Trying https -> http instead. Downloading {url} to {file_path}")
|
|
116
112
|
_urlretrieve(url, file_path)
|
|
117
113
|
else:
|
|
118
114
|
raise e
|
onnxtr/utils/fonts.py
CHANGED
|
@@ -5,25 +5,20 @@
|
|
|
5
5
|
|
|
6
6
|
import logging
|
|
7
7
|
import platform
|
|
8
|
-
from typing import Optional, Union
|
|
9
8
|
|
|
10
9
|
from PIL import ImageFont
|
|
11
10
|
|
|
12
11
|
__all__ = ["get_font"]
|
|
13
12
|
|
|
14
13
|
|
|
15
|
-
def get_font(
|
|
16
|
-
font_family: Optional[str] = None, font_size: int = 13
|
|
17
|
-
) -> Union[ImageFont.FreeTypeFont, ImageFont.ImageFont]:
|
|
14
|
+
def get_font(font_family: str | None = None, font_size: int = 13) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
|
|
18
15
|
"""Resolves a compatible ImageFont for the system
|
|
19
16
|
|
|
20
17
|
Args:
|
|
21
|
-
----
|
|
22
18
|
font_family: the font family to use
|
|
23
19
|
font_size: the size of the font upon rendering
|
|
24
20
|
|
|
25
21
|
Returns:
|
|
26
|
-
-------
|
|
27
22
|
the Pillow font
|
|
28
23
|
"""
|
|
29
24
|
# Font selection
|