onnxtr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. onnxtr/__init__.py +2 -0
  2. onnxtr/contrib/__init__.py +0 -0
  3. onnxtr/contrib/artefacts.py +131 -0
  4. onnxtr/contrib/base.py +105 -0
  5. onnxtr/file_utils.py +33 -0
  6. onnxtr/io/__init__.py +5 -0
  7. onnxtr/io/elements.py +455 -0
  8. onnxtr/io/html.py +28 -0
  9. onnxtr/io/image.py +56 -0
  10. onnxtr/io/pdf.py +42 -0
  11. onnxtr/io/reader.py +85 -0
  12. onnxtr/models/__init__.py +4 -0
  13. onnxtr/models/_utils.py +141 -0
  14. onnxtr/models/builder.py +355 -0
  15. onnxtr/models/classification/__init__.py +2 -0
  16. onnxtr/models/classification/models/__init__.py +1 -0
  17. onnxtr/models/classification/models/mobilenet.py +120 -0
  18. onnxtr/models/classification/predictor/__init__.py +1 -0
  19. onnxtr/models/classification/predictor/base.py +57 -0
  20. onnxtr/models/classification/zoo.py +76 -0
  21. onnxtr/models/detection/__init__.py +2 -0
  22. onnxtr/models/detection/core.py +101 -0
  23. onnxtr/models/detection/models/__init__.py +3 -0
  24. onnxtr/models/detection/models/differentiable_binarization.py +159 -0
  25. onnxtr/models/detection/models/fast.py +160 -0
  26. onnxtr/models/detection/models/linknet.py +160 -0
  27. onnxtr/models/detection/postprocessor/__init__.py +0 -0
  28. onnxtr/models/detection/postprocessor/base.py +144 -0
  29. onnxtr/models/detection/predictor/__init__.py +1 -0
  30. onnxtr/models/detection/predictor/base.py +54 -0
  31. onnxtr/models/detection/zoo.py +73 -0
  32. onnxtr/models/engine.py +50 -0
  33. onnxtr/models/predictor/__init__.py +1 -0
  34. onnxtr/models/predictor/base.py +175 -0
  35. onnxtr/models/predictor/predictor.py +145 -0
  36. onnxtr/models/preprocessor/__init__.py +1 -0
  37. onnxtr/models/preprocessor/base.py +118 -0
  38. onnxtr/models/recognition/__init__.py +2 -0
  39. onnxtr/models/recognition/core.py +28 -0
  40. onnxtr/models/recognition/models/__init__.py +5 -0
  41. onnxtr/models/recognition/models/crnn.py +226 -0
  42. onnxtr/models/recognition/models/master.py +145 -0
  43. onnxtr/models/recognition/models/parseq.py +134 -0
  44. onnxtr/models/recognition/models/sar.py +134 -0
  45. onnxtr/models/recognition/models/vitstr.py +166 -0
  46. onnxtr/models/recognition/predictor/__init__.py +1 -0
  47. onnxtr/models/recognition/predictor/_utils.py +86 -0
  48. onnxtr/models/recognition/predictor/base.py +79 -0
  49. onnxtr/models/recognition/utils.py +89 -0
  50. onnxtr/models/recognition/zoo.py +69 -0
  51. onnxtr/models/zoo.py +114 -0
  52. onnxtr/transforms/__init__.py +1 -0
  53. onnxtr/transforms/base.py +112 -0
  54. onnxtr/utils/__init__.py +4 -0
  55. onnxtr/utils/common_types.py +18 -0
  56. onnxtr/utils/data.py +126 -0
  57. onnxtr/utils/fonts.py +41 -0
  58. onnxtr/utils/geometry.py +498 -0
  59. onnxtr/utils/multithreading.py +50 -0
  60. onnxtr/utils/reconstitution.py +70 -0
  61. onnxtr/utils/repr.py +64 -0
  62. onnxtr/utils/visualization.py +291 -0
  63. onnxtr/utils/vocabs.py +71 -0
  64. onnxtr/version.py +1 -0
  65. onnxtr-0.1.0.dist-info/LICENSE +201 -0
  66. onnxtr-0.1.0.dist-info/METADATA +481 -0
  67. onnxtr-0.1.0.dist-info/RECORD +70 -0
  68. onnxtr-0.1.0.dist-info/WHEEL +5 -0
  69. onnxtr-0.1.0.dist-info/top_level.txt +2 -0
  70. onnxtr-0.1.0.dist-info/zip-safe +1 -0
onnxtr/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import io, models, contrib, transforms, utils
2
+ from .version import __version__ # noqa: F401
onnxtr/contrib/__init__.py ADDED
File without changes
onnxtr/contrib/artefacts.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ from onnxtr.file_utils import requires_package
12
+
13
+ from .base import _BasePredictor
14
+
15
+ __all__ = ["ArtefactDetector"]
16
+
17
# Default settings per supported architecture name:
#   - "input_shape": read by `ArtefactDetector.preprocess` as (channels, height, width)
#   - "labels": class names, indexed by the predicted class id
#   - "url": location the ONNX model is downloaded from when no local path is given
default_cfgs: Dict[str, Dict[str, Any]] = {
    "yolov8_artefact": dict(
        input_shape=(3, 1024, 1024),
        labels=["bar_code", "qr_code", "logo", "photo"],
        url="https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/yolo_artefact-f9d66f14.onnx",
    ),
}
24
+
25
+
26
class ArtefactDetector(_BasePredictor):
    """
    A class to detect artefacts in images

    >>> from onnxtr.io import DocumentFile
    >>> from onnxtr.contrib.artefacts import ArtefactDetector
    >>> doc = DocumentFile.from_images(["path/to/image.jpg"])
    >>> detector = ArtefactDetector()
    >>> results = detector(doc)

    Args:
    ----
        arch: the architecture to use (must be a key of `default_cfgs`)
        batch_size: the batch size to use
        model_path: the path to the model to use
        labels: the labels to use
        input_shape: the input shape to use, read as (channels, height, width)
        conf_threshold: the confidence threshold to use
        iou_threshold: the intersection over union threshold to use
        **kwargs: additional arguments to be passed to `download_from_url`
    """

    def __init__(
        self,
        arch: str = "yolov8_artefact",
        batch_size: int = 2,
        model_path: Optional[str] = None,
        labels: Optional[List[str]] = None,
        input_shape: Optional[Tuple[int, int, int]] = None,
        conf_threshold: float = 0.5,
        iou_threshold: float = 0.5,
        **kwargs: Any,
    ) -> None:
        super().__init__(batch_size=batch_size, url=default_cfgs[arch]["url"], model_path=model_path, **kwargs)
        # Fall back to the architecture defaults when no override is given
        self.labels = labels or default_cfgs[arch]["labels"]
        self.input_shape = input_shape or default_cfgs[arch]["input_shape"]
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """
        Resize the image to the model input size, move channels first and scale by 1/255

        Args:
        ----
            img: the input image, indexed as (height, width, channels)

        Returns:
        -------
            np.ndarray: the channels-first image divided by 255
        """
        # cv2.resize expects the target size as (width, height)
        return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)

    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
        """
        Convert the raw model outputs into one list of detections per input image

        Each candidate row is read as (center x, center y, width, height, *class scores) in model input
        coordinates. Candidates below `conf_threshold` are dropped, the remaining boxes are rescaled to the
        size of the image they belong to, and overlapping boxes are removed with non-maximum suppression.

        Args:
        ----
            output: the session outputs, one entry per batch
            input_images: the original images, grouped in the same batches as `output`

        Returns:
        -------
            List[List[Dict[str, Any]]]: for each image, a list of dicts with the keys "label", "confidence"
            and "box" (`[xmin, ymin, xmax, ymax]` in pixels of the original image)
        """
        results = []
        in_height, in_width = self.input_shape[1], self.input_shape[2]

        for batch_outputs, batch_images in zip(output, input_images):
            # Pair every per-sample prediction with its own image so that each box is rescaled with the
            # dimensions of the image it was detected on (not with those of the first image of the batch).
            # NOTE(review): assumes the first output tensor holds one prediction per batch image - confirm
            for res, img in zip(batch_outputs[0], batch_images):
                org_height, org_width = img.shape[:2]
                width_scale, height_scale = org_width / in_width, org_height / in_height

                sample_results = []
                nms_boxes = []
                for row in np.transpose(np.squeeze(res)):
                    classes_scores = row[4:]
                    max_score = np.amax(classes_scores)
                    if max_score >= self.conf_threshold:
                        class_id = int(np.argmax(classes_scores))
                        x, y, w, h = row[0], row[1], row[2], row[3]
                        # to rescaled xmin, ymin, xmax, ymax
                        xmin = int((x - w / 2) * width_scale)
                        ymin = int((y - h / 2) * height_scale)
                        xmax = int((x + w / 2) * width_scale)
                        ymax = int((y + h / 2) * height_scale)

                        sample_results.append({
                            "label": self.labels[class_id],
                            "confidence": float(max_score),
                            "box": [xmin, ymin, xmax, ymax],
                        })
                        # cv2.dnn.NMSBoxes expects boxes as [x, y, width, height]
                        nms_boxes.append([xmin, ymin, xmax - xmin, ymax - ymin])

                # Filter out overlapping boxes
                if sample_results:
                    scores = [det["confidence"] for det in sample_results]
                    keep_indices = cv2.dnn.NMSBoxes(nms_boxes, scores, self.conf_threshold, self.iou_threshold)  # type: ignore[arg-type]
                    # Depending on the OpenCV version the indices come back flat or as an (N, 1) array
                    sample_results = [sample_results[int(i)] for i in np.asarray(keep_indices).reshape(-1)]

                results.append(sample_results)

        self._results = results
        return results

    def show(self, **kwargs: Any) -> None:
        """
        Display the results

        Args:
        ----
            **kwargs: additional keyword arguments to be passed to `plt.show`
        """
        requires_package("matplotlib", "`.show()` requires matplotlib installed")
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle

        # visualize the results with matplotlib
        if self._results and self._inputs:
            for img, res in zip(self._inputs, self._results):
                # One figure per image, with every detection drawn on top of it
                plt.figure(figsize=(10, 10))
                plt.imshow(img)
                for obj in res:
                    xmin, ymin, xmax, ymax = obj["box"]
                    label = obj["label"]
                    plt.text(xmin, ymin, f"{label} {obj['confidence']:.2f}", color="red")
                    plt.gca().add_patch(
                        Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="red", linewidth=2)
                    )
            plt.show(**kwargs)
onnxtr/contrib/base.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from typing import Any, List, Optional
7
+
8
+ import numpy as np
9
+
10
+ from onnxtr.file_utils import requires_package
11
+ from onnxtr.utils.data import download_from_url
12
+
13
+
14
class _BasePredictor:
    """
    Base class for all predictors

    Subclasses implement `preprocess` and `postprocess`; `__call__` batches the inputs, runs the ONNX
    session on each batch and hands the raw outputs to `postprocess`.

    Args:
    ----
        batch_size: the batch size to use
        url: the url to use to download a model if needed
        model_path: the path to the model to use
        **kwargs: additional arguments to be passed to `download_from_url`
    """

    def __init__(
        self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any
    ) -> None:
        self.batch_size = batch_size
        self.session = self._init_model(url, model_path, **kwargs)

        # Last inputs / results, kept so that subclasses can visualize them afterwards
        self._inputs: List[np.ndarray] = []
        self._results: List[Any] = []

    def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
        """
        Download the model from the given url if needed

        Args:
        ----
            url: the url to use
            model_path: the path to the model to use
            **kwargs: additional arguments to be passed to `download_from_url`

        Returns:
        -------
            Any: the ONNX loaded model

        Raises:
        ------
            ValueError: if neither `url` nor `model_path` is provided
        """
        requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
        import onnxruntime as ort

        if not url and not model_path:
            raise ValueError("You must provide either a url or a model_path")
        # A local model path takes precedence over downloading from the url
        onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs))  # type: ignore[arg-type]

        # Only request execution providers this onnxruntime build actually offers (CUDA first, then CPU),
        # instead of unconditionally asking for CUDA on CPU-only installations.
        available = set(ort.get_available_providers())
        preferred = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        providers = [name for name in preferred if name in available] or ["CPUExecutionProvider"]
        return ort.InferenceSession(onnx_model_path, providers=providers)

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """
        Preprocess the input image

        Args:
        ----
            img: the input image to preprocess

        Returns:
        -------
            np.ndarray: the preprocessed image
        """
        raise NotImplementedError

    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
        """
        Postprocess the model output

        Args:
        ----
            output: the model output to postprocess
            input_images: the input images used to generate the output

        Returns:
        -------
            Any: the postprocessed output
        """
        raise NotImplementedError

    def __call__(self, inputs: List[np.ndarray]) -> Any:
        """
        Call the model on the given inputs

        Args:
        ----
            inputs: the inputs to use

        Returns:
        -------
            Any: the postprocessed output
        """
        self._inputs = inputs
        model_inputs = self.session.get_inputs()

        # Split the inputs into consecutive chunks of at most `batch_size` images
        batched_inputs = [inputs[i : i + self.batch_size] for i in range(0, len(inputs), self.batch_size)]
        # Stack the preprocessed images of every chunk into one float32 array
        processed_batches = [
            np.array([self.preprocess(img) for img in batch], dtype=np.float32) for batch in batched_inputs
        ]

        # Each batch is fed to the first declared input of the session
        outputs = [self.session.run(None, {model_inputs[0].name: batch}) for batch in processed_batches]
        return self.postprocess(outputs, batched_inputs)
onnxtr/file_utils.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import importlib.metadata
7
+ import importlib.util
8
+ import logging
9
+ from typing import Optional
10
+
11
+ __all__ = ["requires_package"]
12
+
13
# Upper-case string tokens treated as "true"; the second set additionally accepts "AUTO".
# NOTE(review): presumably compared against upper-cased environment variable values - not used in this module.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {"AUTO"}
15
+
16
+
17
def requires_package(name: str, extra_message: Optional[str] = None) -> None:  # pragma: no cover
    """
    package requirement helper

    Looks up the installed distribution metadata for `name` and raises if it cannot be found.

    Args:
    ----
        name: name of the package
        extra_message: additional message to display if the package is not found

    Raises:
    ------
        ImportError: if no installed distribution named `name` is found
    """
    try:
        _pkg_version = importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError as err:
        raise ImportError(
            f"\n\n{extra_message if extra_message is not None else ''} "
            f"\nPlease install it with the following command: pip install {name}\n"
        ) from err
    else:
        # Use a module logger: the root-level `logging.info` helper implicitly calls `basicConfig()` when the
        # root logger has no handlers, which would alter the logging setup of the host application.
        logging.getLogger(__name__).info("%s version %s available.", name, _pkg_version)
onnxtr/io/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .elements import *
2
+ from .html import *
3
+ from .image import *
4
+ from .pdf import *
5
+ from .reader import *