onnxtr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxtr/__init__.py +2 -0
- onnxtr/contrib/__init__.py +0 -0
- onnxtr/contrib/artefacts.py +131 -0
- onnxtr/contrib/base.py +105 -0
- onnxtr/file_utils.py +33 -0
- onnxtr/io/__init__.py +5 -0
- onnxtr/io/elements.py +455 -0
- onnxtr/io/html.py +28 -0
- onnxtr/io/image.py +56 -0
- onnxtr/io/pdf.py +42 -0
- onnxtr/io/reader.py +85 -0
- onnxtr/models/__init__.py +4 -0
- onnxtr/models/_utils.py +141 -0
- onnxtr/models/builder.py +355 -0
- onnxtr/models/classification/__init__.py +2 -0
- onnxtr/models/classification/models/__init__.py +1 -0
- onnxtr/models/classification/models/mobilenet.py +120 -0
- onnxtr/models/classification/predictor/__init__.py +1 -0
- onnxtr/models/classification/predictor/base.py +57 -0
- onnxtr/models/classification/zoo.py +76 -0
- onnxtr/models/detection/__init__.py +2 -0
- onnxtr/models/detection/core.py +101 -0
- onnxtr/models/detection/models/__init__.py +3 -0
- onnxtr/models/detection/models/differentiable_binarization.py +159 -0
- onnxtr/models/detection/models/fast.py +160 -0
- onnxtr/models/detection/models/linknet.py +160 -0
- onnxtr/models/detection/postprocessor/__init__.py +0 -0
- onnxtr/models/detection/postprocessor/base.py +144 -0
- onnxtr/models/detection/predictor/__init__.py +1 -0
- onnxtr/models/detection/predictor/base.py +54 -0
- onnxtr/models/detection/zoo.py +73 -0
- onnxtr/models/engine.py +50 -0
- onnxtr/models/predictor/__init__.py +1 -0
- onnxtr/models/predictor/base.py +175 -0
- onnxtr/models/predictor/predictor.py +145 -0
- onnxtr/models/preprocessor/__init__.py +1 -0
- onnxtr/models/preprocessor/base.py +118 -0
- onnxtr/models/recognition/__init__.py +2 -0
- onnxtr/models/recognition/core.py +28 -0
- onnxtr/models/recognition/models/__init__.py +5 -0
- onnxtr/models/recognition/models/crnn.py +226 -0
- onnxtr/models/recognition/models/master.py +145 -0
- onnxtr/models/recognition/models/parseq.py +134 -0
- onnxtr/models/recognition/models/sar.py +134 -0
- onnxtr/models/recognition/models/vitstr.py +166 -0
- onnxtr/models/recognition/predictor/__init__.py +1 -0
- onnxtr/models/recognition/predictor/_utils.py +86 -0
- onnxtr/models/recognition/predictor/base.py +79 -0
- onnxtr/models/recognition/utils.py +89 -0
- onnxtr/models/recognition/zoo.py +69 -0
- onnxtr/models/zoo.py +114 -0
- onnxtr/transforms/__init__.py +1 -0
- onnxtr/transforms/base.py +112 -0
- onnxtr/utils/__init__.py +4 -0
- onnxtr/utils/common_types.py +18 -0
- onnxtr/utils/data.py +126 -0
- onnxtr/utils/fonts.py +41 -0
- onnxtr/utils/geometry.py +498 -0
- onnxtr/utils/multithreading.py +50 -0
- onnxtr/utils/reconstitution.py +70 -0
- onnxtr/utils/repr.py +64 -0
- onnxtr/utils/visualization.py +291 -0
- onnxtr/utils/vocabs.py +71 -0
- onnxtr/version.py +1 -0
- onnxtr-0.1.0.dist-info/LICENSE +201 -0
- onnxtr-0.1.0.dist-info/METADATA +481 -0
- onnxtr-0.1.0.dist-info/RECORD +70 -0
- onnxtr-0.1.0.dist-info/WHEEL +5 -0
- onnxtr-0.1.0.dist-info/top_level.txt +2 -0
- onnxtr-0.1.0.dist-info/zip-safe +1 -0
onnxtr/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from onnxtr.file_utils import requires_package
|
|
12
|
+
|
|
13
|
+
from .base import _BasePredictor
|
|
14
|
+
|
|
15
|
+
__all__ = ["ArtefactDetector"]
|
|
16
|
+
|
|
17
|
+
# Architecture registry: maps an architecture name to its ONNX export metadata.
# "input_shape" is CHW, "labels" is the ordered class list matching the model's
# output channels, and "url" points at the released ONNX weights.
default_cfgs: Dict[str, Dict[str, Any]] = {
    "yolov8_artefact": {
        "input_shape": (3, 1024, 1024),
        "labels": ["bar_code", "qr_code", "logo", "photo"],
        "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/yolo_artefact-f9d66f14.onnx",
    },
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ArtefactDetector(_BasePredictor):
    """
    A class to detect artefacts in images

    >>> from onnxtr.io import DocumentFile
    >>> from onnxtr.contrib.artefacts import ArtefactDetector
    >>> doc = DocumentFile.from_images(["path/to/image.jpg"])
    >>> detector = ArtefactDetector()
    >>> results = detector(doc)

    Args:
    ----
        arch: the architecture to use
        batch_size: the batch size to use
        model_path: the path to the model to use
        labels: the labels to use
        input_shape: the input shape to use
        conf_threshold: the confidence threshold to use
        iou_threshold: the intersection over union threshold to use
        **kwargs: additional arguments to be passed to `download_from_url`
    """

    def __init__(
        self,
        arch: str = "yolov8_artefact",
        batch_size: int = 2,
        model_path: Optional[str] = None,
        labels: Optional[List[str]] = None,
        input_shape: Optional[Tuple[int, int, int]] = None,
        conf_threshold: float = 0.5,
        iou_threshold: float = 0.5,
        **kwargs: Any,
    ) -> None:
        super().__init__(batch_size=batch_size, url=default_cfgs[arch]["url"], model_path=model_path, **kwargs)
        # Fall back to the registry defaults when no override is given.
        self.labels = labels or default_cfgs[arch]["labels"]
        self.input_shape = input_shape or default_cfgs[arch]["input_shape"]
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        # Resize to the model's spatial size (input_shape is CHW, cv2.resize wants (W, H)),
        # move channels first (HWC -> CHW) and scale pixel values to [0, 1].
        return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / 255.0

    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
        """
        Decode raw YOLO outputs into per-image detections.

        Args:
        ----
            output: raw model outputs, one entry per batch
            input_images: the (unprocessed) input image batches used to generate the output

        Returns:
        -------
            one list of {"label", "confidence", "box"} dicts per image,
            with boxes as [xmin, ymin, xmax, ymax] in original image coordinates
        """
        results = []

        for batch_out, batch_imgs in zip(output, input_images):
            for out, img in zip(batch_out, batch_imgs):
                org_height, org_width = img.shape[:2]
                # Scale factors mapping model-space coordinates back to the original image.
                width_scale, height_scale = org_width / self.input_shape[2], org_height / self.input_shape[1]
                for res in out:
                    sample_results = []
                    for row in np.transpose(np.squeeze(res)):
                        # Row layout: [cx, cy, w, h, score_class_0, score_class_1, ...]
                        classes_scores = row[4:]
                        max_score = np.amax(classes_scores)
                        if max_score >= self.conf_threshold:
                            class_id = np.argmax(classes_scores)
                            x, y, w, h = row[0], row[1], row[2], row[3]
                            # to rescaled xmin, ymin, xmax, ymax
                            xmin = int((x - w / 2) * width_scale)
                            ymin = int((y - h / 2) * height_scale)
                            xmax = int((x + w / 2) * width_scale)
                            ymax = int((y + h / 2) * height_scale)

                            sample_results.append({
                                "label": self.labels[class_id],
                                "confidence": float(max_score),
                                "box": [xmin, ymin, xmax, ymax],
                            })

                    # Filter out overlapping boxes.
                    # Guard: cv2.dnn.NMSBoxes raises on empty inputs in several OpenCV versions.
                    if sample_results:
                        # NOTE(review): NMSBoxes historically expects (x, y, w, h) boxes while we
                        # pass (xmin, ymin, xmax, ymax) — confirm against the pinned OpenCV version.
                        boxes = [det["box"] for det in sample_results]
                        scores = [det["confidence"] for det in sample_results]
                        keep_indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf_threshold, self.iou_threshold)  # type: ignore[arg-type]
                        sample_results = [sample_results[i] for i in keep_indices]

                    results.append(sample_results)

        # Keep a copy so `.show()` can visualize the last run.
        self._results = results
        return results

    def show(self, **kwargs: Any) -> None:
        """
        Display the results

        Args:
        ----
            **kwargs: additional keyword arguments to be passed to `plt.show`
        """
        requires_package("matplotlib", "`.show()` requires matplotlib installed")
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle

        # visualize the results with matplotlib
        if self._results and self._inputs:
            for img, res in zip(self._inputs, self._results):
                plt.figure(figsize=(10, 10))
                plt.imshow(img)
                for obj in res:
                    xmin, ymin, xmax, ymax = obj["box"]
                    label = obj["label"]
                    plt.text(xmin, ymin, f"{label} {obj['confidence']:.2f}", color="red")
                    plt.gca().add_patch(
                        Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="red", linewidth=2)
                    )
            plt.show(**kwargs)
onnxtr/contrib/base.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from typing import Any, List, Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from onnxtr.file_utils import requires_package
|
|
11
|
+
from onnxtr.utils.data import download_from_url
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class _BasePredictor:
    """
    Base class for all predictors

    Args:
    ----
        batch_size: the batch size to use
        url: the url to use to download a model if needed
        model_path: the path to the model to use
        **kwargs: additional arguments to be passed to `download_from_url`
    """

    def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None:
        self.batch_size = batch_size
        self.session = self._init_model(url, model_path, **kwargs)

        # Populated on each __call__ so subclasses (e.g. `.show()`) can reuse them.
        self._inputs: List[np.ndarray] = []
        self._results: List[Any] = []

    def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
        """
        Download the model from the given url if needed

        Args:
        ----
            url: the url to use
            model_path: the path to the model to use
            **kwargs: additional arguments to be passed to `download_from_url`

        Returns:
        -------
            Any: the ONNX loaded model
        """
        requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
        import onnxruntime as ort

        if not url and not model_path:
            raise ValueError("You must provide either a url or a model_path")
        # A local model path takes precedence over downloading from the url.
        if model_path:
            onnx_model_path = model_path
        else:
            onnx_model_path = str(download_from_url(url, cache_subdir="models", **kwargs))  # type: ignore[arg-type]
        return ort.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """
        Preprocess the input image

        Args:
        ----
            img: the input image to preprocess

        Returns:
        -------
            np.ndarray: the preprocessed image
        """
        raise NotImplementedError

    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
        """
        Postprocess the model output

        Args:
        ----
            output: the model output to postprocess
            input_images: the input images used to generate the output

        Returns:
        -------
            Any: the postprocessed output
        """
        raise NotImplementedError

    def __call__(self, inputs: List[np.ndarray]) -> Any:
        """
        Call the model on the given inputs

        Args:
        ----
            inputs: the inputs to use

        Returns:
        -------
            Any: the postprocessed output
        """
        self._inputs = inputs
        model_inputs = self.session.get_inputs()
        input_name = model_inputs[0].name

        # Split the inputs into fixed-size batches, then preprocess each batch
        # into a single float32 array for the ONNX session.
        batches = [inputs[idx : idx + self.batch_size] for idx in range(0, len(inputs), self.batch_size)]
        prepared = [np.array([self.preprocess(img) for img in batch], dtype=np.float32) for batch in batches]

        raw_outputs = []
        for batch in prepared:
            raw_outputs.append(self.session.run(None, {input_name: batch}))
        return self.postprocess(raw_outputs, batches)
|
onnxtr/file_utils.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
import importlib.metadata
|
|
7
|
+
import importlib.util
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
__all__ = ["requires_package"]
|
|
12
|
+
|
|
13
|
+
# Canonical truthy values for environment-variable flags.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})


def requires_package(name: str, extra_message: Optional[str] = None) -> None:  # pragma: no cover
    """
    package requirement helper

    Args:
    ----
        name: name of the package
        extra_message: additional message to display if the package is not found
    """
    try:
        # `version` raises PackageNotFoundError when the distribution is absent.
        _pkg_version = importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        extra = extra_message if extra_message is not None else ""
        raise ImportError(
            f"\n\n{extra} "
            f"\nPlease install it with the following command: pip install {name}\n"
        )
    else:
        logging.info(f"{name} version {_pkg_version} available.")
|