onnxtr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. onnxtr/__init__.py +2 -0
  2. onnxtr/contrib/__init__.py +0 -0
  3. onnxtr/contrib/artefacts.py +131 -0
  4. onnxtr/contrib/base.py +105 -0
  5. onnxtr/file_utils.py +33 -0
  6. onnxtr/io/__init__.py +5 -0
  7. onnxtr/io/elements.py +455 -0
  8. onnxtr/io/html.py +28 -0
  9. onnxtr/io/image.py +56 -0
  10. onnxtr/io/pdf.py +42 -0
  11. onnxtr/io/reader.py +85 -0
  12. onnxtr/models/__init__.py +4 -0
  13. onnxtr/models/_utils.py +141 -0
  14. onnxtr/models/builder.py +355 -0
  15. onnxtr/models/classification/__init__.py +2 -0
  16. onnxtr/models/classification/models/__init__.py +1 -0
  17. onnxtr/models/classification/models/mobilenet.py +120 -0
  18. onnxtr/models/classification/predictor/__init__.py +1 -0
  19. onnxtr/models/classification/predictor/base.py +57 -0
  20. onnxtr/models/classification/zoo.py +76 -0
  21. onnxtr/models/detection/__init__.py +2 -0
  22. onnxtr/models/detection/core.py +101 -0
  23. onnxtr/models/detection/models/__init__.py +3 -0
  24. onnxtr/models/detection/models/differentiable_binarization.py +159 -0
  25. onnxtr/models/detection/models/fast.py +160 -0
  26. onnxtr/models/detection/models/linknet.py +160 -0
  27. onnxtr/models/detection/postprocessor/__init__.py +0 -0
  28. onnxtr/models/detection/postprocessor/base.py +144 -0
  29. onnxtr/models/detection/predictor/__init__.py +1 -0
  30. onnxtr/models/detection/predictor/base.py +54 -0
  31. onnxtr/models/detection/zoo.py +73 -0
  32. onnxtr/models/engine.py +50 -0
  33. onnxtr/models/predictor/__init__.py +1 -0
  34. onnxtr/models/predictor/base.py +175 -0
  35. onnxtr/models/predictor/predictor.py +145 -0
  36. onnxtr/models/preprocessor/__init__.py +1 -0
  37. onnxtr/models/preprocessor/base.py +118 -0
  38. onnxtr/models/recognition/__init__.py +2 -0
  39. onnxtr/models/recognition/core.py +28 -0
  40. onnxtr/models/recognition/models/__init__.py +5 -0
  41. onnxtr/models/recognition/models/crnn.py +226 -0
  42. onnxtr/models/recognition/models/master.py +145 -0
  43. onnxtr/models/recognition/models/parseq.py +134 -0
  44. onnxtr/models/recognition/models/sar.py +134 -0
  45. onnxtr/models/recognition/models/vitstr.py +166 -0
  46. onnxtr/models/recognition/predictor/__init__.py +1 -0
  47. onnxtr/models/recognition/predictor/_utils.py +86 -0
  48. onnxtr/models/recognition/predictor/base.py +79 -0
  49. onnxtr/models/recognition/utils.py +89 -0
  50. onnxtr/models/recognition/zoo.py +69 -0
  51. onnxtr/models/zoo.py +114 -0
  52. onnxtr/transforms/__init__.py +1 -0
  53. onnxtr/transforms/base.py +112 -0
  54. onnxtr/utils/__init__.py +4 -0
  55. onnxtr/utils/common_types.py +18 -0
  56. onnxtr/utils/data.py +126 -0
  57. onnxtr/utils/fonts.py +41 -0
  58. onnxtr/utils/geometry.py +498 -0
  59. onnxtr/utils/multithreading.py +50 -0
  60. onnxtr/utils/reconstitution.py +70 -0
  61. onnxtr/utils/repr.py +64 -0
  62. onnxtr/utils/visualization.py +291 -0
  63. onnxtr/utils/vocabs.py +71 -0
  64. onnxtr/version.py +1 -0
  65. onnxtr-0.1.0.dist-info/LICENSE +201 -0
  66. onnxtr-0.1.0.dist-info/METADATA +481 -0
  67. onnxtr-0.1.0.dist-info/RECORD +70 -0
  68. onnxtr-0.1.0.dist-info/WHEEL +5 -0
  69. onnxtr-0.1.0.dist-info/top_level.txt +2 -0
  70. onnxtr-0.1.0.dist-info/zip-safe +1 -0
onnxtr/models/zoo.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from typing import Any
7
+
8
+ from .detection.zoo import detection_predictor
9
+ from .predictor import OCRPredictor
10
+ from .recognition.zoo import recognition_predictor
11
+
12
+ __all__ = ["ocr_predictor"]
13
+
14
+
15
def _predictor(
    det_arch: Any,
    reco_arch: Any,
    assume_straight_pages: bool = True,
    preserve_aspect_ratio: bool = True,
    symmetric_pad: bool = True,
    det_bs: int = 4,
    reco_bs: int = 1024,
    detect_orientation: bool = False,
    straighten_pages: bool = False,
    detect_language: bool = False,
    **kwargs,
) -> OCRPredictor:
    """Assemble an end-to-end OCR predictor from a detection and a recognition architecture.

    Args:
    ----
        det_arch: detection architecture name or model instance
        reco_arch: recognition architecture name or model instance
        assume_straight_pages: speed up inference by assuming non-rotated pages
        preserve_aspect_ratio: pad inputs to keep their aspect ratio before detection
        symmetric_pad: pad symmetrically rather than bottom-right only
        det_bs: batch size for the detection stage
        reco_bs: batch size for the recognition stage
        detect_orientation: add the estimated page orientation to the predictions
        straighten_pages: rotate pages to their estimated orientation before detection
        detect_language: add a language prediction for each page
        kwargs: forwarded to `OCRPredictor`

    Returns:
    -------
        the assembled `OCRPredictor`
    """
    # Stage 1: text localization
    detection = detection_predictor(
        det_arch,
        batch_size=det_bs,
        assume_straight_pages=assume_straight_pages,
        preserve_aspect_ratio=preserve_aspect_ratio,
        symmetric_pad=symmetric_pad,
    )

    # Stage 2: text recognition on the detected crops
    recognition = recognition_predictor(reco_arch, batch_size=reco_bs)

    return OCRPredictor(
        detection,
        recognition,
        assume_straight_pages=assume_straight_pages,
        preserve_aspect_ratio=preserve_aspect_ratio,
        symmetric_pad=symmetric_pad,
        detect_orientation=detect_orientation,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
        **kwargs,
    )
54
+
55
+
56
def ocr_predictor(
    det_arch: Any = "fast_base",
    reco_arch: Any = "crnn_vgg16_bn",
    assume_straight_pages: bool = True,
    preserve_aspect_ratio: bool = True,
    symmetric_pad: bool = True,
    export_as_straight_boxes: bool = False,
    detect_orientation: bool = False,
    straighten_pages: bool = False,
    detect_language: bool = False,
    **kwargs: Any,
) -> OCRPredictor:
    """End-to-end OCR architecture pairing one localization model with one text-recognition model.

    >>> import numpy as np
    >>> from onnxtr.models import ocr_predictor
    >>> model = ocr_predictor('db_resnet50', 'crnn_vgg16_bn')
    >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
    >>> out = model([input_page])

    Args:
    ----
        det_arch: name of the detection architecture, or the model itself
            (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
        reco_arch: name of the recognition architecture, or the model itself
            (e.g. 'crnn_vgg16_bn', 'sar_resnet31')
        assume_straight_pages: if True, speeds up inference by assuming pages contain
            only straight (non-rotated) textual elements
        preserve_aspect_ratio: if True, pad the input document image so its aspect
            ratio is preserved before detection
        symmetric_pad: if True, pad symmetrically instead of padding at the bottom-right
        export_as_straight_boxes: when assume_straight_pages is False, export the final
            (potentially rotated) predictions as straight bounding boxes
        detect_orientation: if True, add the estimated general page orientation to the
            predictions for each page, at a small latency cost
        straighten_pages: if True, estimate the general page orientation from the
            segmentation map median line orientation, then rotate the page before
            re-running detection; improves results on page-uniform rotations
        detect_language: if True, add a language prediction for each page, at a small
            latency cost
        kwargs: keyword args of `OCRPredictor`

    Returns:
    -------
        OCR predictor
    """
    # All heavy lifting is delegated to the shared builder
    return _predictor(
        det_arch,
        reco_arch,
        assume_straight_pages=assume_straight_pages,
        preserve_aspect_ratio=preserve_aspect_ratio,
        symmetric_pad=symmetric_pad,
        export_as_straight_boxes=export_as_straight_boxes,
        detect_orientation=detect_orientation,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
        **kwargs,
    )
@@ -0,0 +1 @@
1
+ from .base import *
@@ -0,0 +1,112 @@
1
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from typing import Tuple, Union
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ __all__ = ["Resize", "Normalize"]
12
+
13
+
14
class Resize:
    """Resize the input image to the given size, optionally preserving its aspect ratio.

    Args:
    ----
        size: target size as ``(height, width)``, or a single int for a square output.
            NOTE(review): a list passes the validation below but breaks the unpacking in
            ``__call__`` — callers appear to use int or tuple only; confirm before tightening.
        interpolation: OpenCV interpolation flag used for resizing
        preserve_aspect_ratio: if True, fit the image inside the target canvas and pad
            the remainder with zeros
        symmetric_pad: if True, pad equally on opposite sides; otherwise pad
            bottom/right only
    """

    def __init__(
        self,
        size: Union[int, Tuple[int, int]],
        interpolation=cv2.INTER_LINEAR,
        preserve_aspect_ratio: bool = False,
        symmetric_pad: bool = False,
    ) -> None:
        super().__init__()
        self.size = size
        self.interpolation = interpolation
        self.preserve_aspect_ratio = preserve_aspect_ratio
        self.symmetric_pad = symmetric_pad
        self.output_size = size if isinstance(size, tuple) else (size, size)

        if not isinstance(self.size, (int, tuple, list)):
            raise AssertionError("size should be either a tuple, a list or an int")

    def __call__(
        self,
        img: np.ndarray,
    ) -> np.ndarray:
        if img.ndim == 3:
            h, w = img.shape[0:2]
        else:
            # assumes a batched (N, H, W, C) array — TODO confirm against callers
            h, w = img.shape[1:3]
        sh, sw = self.size if isinstance(self.size, tuple) else (self.size, self.size)

        # Aspect ratio of the input image
        aspect = w / h

        if self.preserve_aspect_ratio:
            # Scale so the image fits entirely inside the target canvas
            if aspect > 1:  # Horizontal image
                new_w = sw
                new_h = int(sw / aspect)
            elif aspect < 1:  # Vertical image
                new_h = sh
                new_w = int(sh * aspect)
            else:  # Square image
                new_h, new_w = sh, sw

            img_resized = cv2.resize(img, (new_w, new_h), interpolation=self.interpolation)

            # BUGFIX: honor `symmetric_pad` — previously the image was always centered,
            # even when symmetric_pad=False (which should pad bottom/right only).
            if self.symmetric_pad:
                pad_top = max((sh - new_h) // 2, 0)
                pad_left = max((sw - new_w) // 2, 0)
            else:
                pad_top = 0
                pad_left = 0
            pad_bottom = max(sh - new_h - pad_top, 0)
            pad_right = max(sw - new_w - pad_left, 0)

            # Pad with zeros (black) up to the target size
            img_resized = cv2.copyMakeBorder(  # type: ignore[call-overload]
                img_resized, pad_top, pad_bottom, pad_left, pad_right, borderType=cv2.BORDER_CONSTANT, value=0
            )

            # Guard against off-by-one rounding: force the exact target size
            img_resized = cv2.resize(img_resized, (sw, sh), interpolation=self.interpolation)
        else:
            # Plain resize, aspect ratio not preserved
            img_resized = cv2.resize(img, (sw, sh), interpolation=self.interpolation)

        return img_resized

    def __repr__(self) -> str:
        interpolate_str = self.interpolation
        _repr = f"output_size={self.size}, interpolation='{interpolate_str}'"
        if self.preserve_aspect_ratio:
            _repr += f", preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}"
        return f"{self.__class__.__name__}({_repr})"
85
+
86
+
87
class Normalize:
    """Normalize an image with the given per-channel mean and standard deviation.

    Args:
    ----
        mean: channel-wise mean (or a single float applied to all channels)
        std: channel-wise standard deviation (or a single float)
    """

    def __init__(
        self,
        mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406),
        std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225),
    ) -> None:
        self.mean = mean
        self.std = std

        for value, label in ((self.mean, "mean"), (self.std, "std")):
            if not isinstance(value, (float, tuple, list)):
                raise AssertionError(f"{label} should be either a tuple, a list or a float")

    def __call__(
        self,
        img: np.ndarray,
    ) -> np.ndarray:
        # Cast mean/std to the image dtype before the arithmetic.
        # NOTE(review): for integer dtypes this truncates the statistics — presumably
        # inputs are float images; confirm against the preprocessor.
        offset = np.asarray(self.mean, dtype=img.dtype)
        scale = np.asarray(self.std, dtype=img.dtype)
        return (img - offset) / scale

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
@@ -0,0 +1,4 @@
1
+ from .common_types import *
2
+ from .data import *
3
+ from .geometry import *
4
+ from .vocabs import *
@@ -0,0 +1,18 @@
1
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from pathlib import Path
7
+ from typing import List, Tuple, Union
8
+
9
__all__ = ["Point2D", "BoundingBox", "Polygon4P", "Polygon", "Bbox"]


# A single 2D point as (x, y) coordinates
Point2D = Tuple[float, float]
# A box given by two corner points — presumably ((xmin, ymin), (xmax, ymax)); confirm against geometry.py
BoundingBox = Tuple[Point2D, Point2D]
# A quadrilateral given by its 4 corner points
Polygon4P = Tuple[Point2D, Point2D, Point2D, Point2D]
# An arbitrary polygon as an ordered list of points
Polygon = List[Point2D]
# A filesystem path, as str or pathlib.Path
AbstractPath = Union[str, Path]
# A path or raw in-memory binary content
AbstractFile = Union[AbstractPath, bytes]
# A flat box of 4 floats — presumably (xmin, ymin, xmax, ymax); confirm against consumers
Bbox = Tuple[float, float, float, float]
onnxtr/utils/data.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ # Adapted from https://github.com/pytorch/vision/blob/master/torchvision/datasets/utils.py
7
+
8
+ import hashlib
9
+ import logging
10
+ import os
11
+ import re
12
+ import urllib
13
+ import urllib.error
14
+ import urllib.request
15
+ from pathlib import Path
16
+ from typing import Optional, Union
17
+
18
+ from tqdm.auto import tqdm
19
+
20
+ __all__ = ["download_from_url"]
21
+
22
+
23
+ # matches bfd8deac from resnet18-bfd8deac.ckpt
24
+ HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
25
+ USER_AGENT = "felixdittrich92/OnnxTR"
26
+
27
+
28
def _urlretrieve(url: str, filename: Union[Path, str], chunk_size: int = 1024) -> None:
    """Stream the resource at `url` to `filename`, showing a progress bar.

    Args:
    ----
        url: the URL to fetch
        filename: destination path on disk
        chunk_size: number of bytes read per iteration
    """
    request = urllib.request.Request(url, headers={"User-Agent": USER_AGENT})
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(request) as response:
            with tqdm(total=response.length) as pbar:
                # BUGFIX: sentinel must be b"" — response.read returns bytes, and
                # b"" != "" so the previous str sentinel never matched.
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    fh.write(chunk)
                    # BUGFIX: count the bytes actually read; the final chunk is
                    # usually shorter than chunk_size.
                    pbar.update(len(chunk))
37
+
38
+
39
+ def _check_integrity(file_path: Union[str, Path], hash_prefix: str) -> bool:
40
+ with open(file_path, "rb") as f:
41
+ sha_hash = hashlib.sha256(f.read()).hexdigest()
42
+
43
+ return sha_hash[: len(hash_prefix)] == hash_prefix
44
+
45
+
46
def download_from_url(
    url: str,
    file_name: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    cache_dir: Optional[str] = None,
    cache_subdir: Optional[str] = None,
) -> Path:
    """Download a file using its URL

    >>> from onnxtr.models import download_from_url
    >>> download_from_url("https://yoursource.com/yourcheckpoint-yourhash.zip")

    Args:
    ----
        url: the URL of the file to download
        file_name: optional name of the file once downloaded
        hash_prefix: optional expected SHA256 hash of the file
        cache_dir: cache directory
        cache_subdir: subfolder to use in the cache

    Returns:
    -------
        the location of the downloaded file

    Note:
    ----
        You can change cache directory location by using `ONNXTR_CACHE_DIR` environment variable.
    """
    if not isinstance(file_name, str):
        # Derive the name from the last URL segment, stripping query parameters
        file_name = url.rpartition("/")[-1].split("&")[0]

    cache_dir = (
        str(os.environ.get("ONNXTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "onnxtr")))
        if cache_dir is None
        else cache_dir
    )

    # Fall back to the hash embedded in the file name (e.g. resnet18-bfd8deac.ckpt)
    if hash_prefix is None:
        r = HASH_REGEX.search(file_name)
        hash_prefix = r.group(1) if r else None

    folder_path = Path(cache_dir) if cache_subdir is None else Path(cache_dir, cache_subdir)
    file_path = folder_path.joinpath(file_name)
    # Reuse a previously downloaded file when its hash (if any) checks out
    if file_path.is_file() and (hash_prefix is None or _check_integrity(file_path, hash_prefix)):
        logging.info(f"Using downloaded & verified file: {file_path}")
        return file_path

    try:
        # Create folder hierarchy
        folder_path.mkdir(parents=True, exist_ok=True)
    except OSError:
        # BUGFIX: "direcotry" -> "directory"
        error_message = f"Failed creating cache directory at {folder_path}"
        if os.environ.get("ONNXTR_CACHE_DIR", ""):
            error_message += " using path from 'ONNXTR_CACHE_DIR' environment variable."
        else:
            error_message += (
                ". You can change default cache directory using 'ONNXTR_CACHE_DIR' environment variable if needed."
            )
        logging.error(error_message)
        raise
    # Download the file
    try:
        print(f"Downloading {url} to {file_path}")
        _urlretrieve(url, file_path)
    except (urllib.error.URLError, IOError) as e:  # pragma: no cover
        if url.startswith("https"):
            # BUGFIX: only rewrite the scheme (first occurrence), not every
            # "https:" substring that may appear later in the URL.
            url = url.replace("https:", "http:", 1)
            print(f"Failed download. Trying https -> http instead. Downloading {url} to {file_path}")
            _urlretrieve(url, file_path)
        else:
            raise e

    # Remove corrupted files
    if isinstance(hash_prefix, str) and not _check_integrity(file_path, hash_prefix):  # pragma: no cover
        os.remove(file_path)
        raise ValueError(f"corrupted download, the hash of {url} does not match its expected value")

    return file_path
onnxtr/utils/fonts.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import logging
7
+ import platform
8
+ from typing import Optional
9
+
10
+ from PIL import ImageFont
11
+
12
+ __all__ = ["get_font"]
13
+
14
+
15
def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont:
    """Resolves a compatible ImageFont for the system

    Args:
    ----
        font_family: the font family to use
        font_size: the size of the font upon rendering

    Returns:
    -------
        the Pillow font
    """
    # Font selection: pick an OS-appropriate default when none is specified
    if font_family is None:
        try:
            font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size)
        except OSError:  # pragma: no cover
            font = ImageFont.load_default()
            # BUGFIX: the implicitly-concatenated message parts previously ran
            # together without separating spaces ("PIL font,font size issues...").
            logging.warning(
                "unable to load recommended font family. Loading default PIL font, "
                "font size issues may be expected. "
                "To prevent this, it is recommended to specify the value of 'font_family'."
            )
    else:  # pragma: no cover
        font = ImageFont.truetype(font_family, font_size)

    return font